@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_q,
                       void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
@@ -80,6 +81,7 @@ void set_params_fprop(Flash_fwd_params &params,
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_q = static_cast<int *>(seqused_q);
     params.seqused_k = static_cast<int *>(seqused_k);
 
     TORCH_CHECK(
@@ -171,6 +173,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       at::Tensor dv,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_q,
+                      void *seqused_k,
                       void *dq_accum_d,
                       void *dk_accum_d,
                       void *dv_accum_d,
@@ -187,7 +191,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                      q, k, v, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
-                     nullptr,
+                     seqused_q,
+                     seqused_k,
                      nullptr,
                      softmax_lse_d,
                      p_dropout,
@@ -364,6 +369,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_q=*/nullptr,
                      /*seqused_k=*/nullptr,
                      nullptr,
                      softmax_lse.data_ptr(),
@@ -426,6 +432,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
+               c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                int max_seqlen_q,
                const int max_seqlen_k,
@@ -482,6 +489,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
 
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    if (seqused_q.has_value()){
+        auto seqused_q_ = seqused_q.value();
+        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
+        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
+        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
+        CHECK_SHAPE(seqused_q_, batch_size);
+    }
+
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
@@ -537,6 +552,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k.data_ptr(),
+                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      /*p_d=*/nullptr,
                      softmax_lse.data_ptr(),
@@ -730,8 +746,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                      head_size, head_size_rounded,
                      q, k, v, out,
                      dout_padded, dq, dk_expanded, dv_expanded,
-                     nullptr,
-                     nullptr,
+                     /*cu_seqlens_q_d=*/nullptr,
+                     /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_q=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      dq_accum.data_ptr(),
                      // loop ? dk_accum.data_ptr() : nullptr,
                      // loop ? dv_accum.data_ptr() : nullptr,
@@ -787,6 +805,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
                c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
+               c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                const int max_seqlen_q,
                const int max_seqlen_k, // max sequence length to choose the kernel
                const float softmax_scale,
@@ -854,7 +874,22 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
     CHECK_SHAPE(out, total_q, num_heads, head_size);
     CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    if (seqused_q.has_value()){
+        auto seqused_q_ = seqused_q.value();
+        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
+        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
+        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
+        CHECK_SHAPE(seqused_q_, batch_size);
+    }
+
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }
 
     at::Tensor dq, dk, dv;
     if (dq_.has_value()) {
@@ -927,6 +962,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
                      dout_padded, dq, dk_expanded, dv_expanded,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      dq_accum.data_ptr(),
                      // loop ? dk_accum.data_ptr() : nullptr,
                      // loop ? dv_accum.data_ptr() : nullptr,
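
Taken together, these hunks thread an optional per-batch `seqused_q` tensor (mirroring the existing `seqused_k`) from the varlen entry points `mha_varlen_fwd`/`mha_varlen_bwd` down into the kernel parameter structs, while the dense paths `mha_fwd`/`mha_bwd` simply pass `nullptr` for both fields. For reference, the contract the patch enforces on both optional tensors can be sketched as a standalone libtorch program; the helper name `check_seqused` and the example `main` below are illustrative only, not part of the patch:

```cpp
#include <torch/torch.h>

// Mirrors the TORCH_CHECK/CHECK_SHAPE validation the patch inlines in
// mha_varlen_fwd and mha_varlen_bwd: if a per-batch "seqused" tensor is
// provided, it must be int32, on a CUDA device, contiguous, and 1-D of
// length batch_size. An absent tensor becomes a nullptr downstream.
static void check_seqused(const c10::optional<at::Tensor> &seqused,
                          int64_t batch_size, const char *name) {
    if (!seqused.has_value()) { return; }
    const at::Tensor &t = seqused.value();
    TORCH_CHECK(t.dtype() == torch::kInt32, name, " must have dtype int32");
    TORCH_CHECK(t.is_cuda(), name, " must be on CUDA device");
    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
    TORCH_CHECK(t.dim() == 1 && t.size(0) == batch_size,
                name, " must have shape (batch_size,)");
}

int main() {
    if (!torch::cuda::is_available()) { return 0; }  // sketch needs a GPU
    auto opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
    // Batch of 2: use only the first 3 queries of element 0 and 7 of element 1.
    c10::optional<at::Tensor> seqused_q = torch::tensor({3, 7}, opts);
    check_seqused(seqused_q, /*batch_size=*/2, "seqused_q");
    return 0;
}
```

Because an absent tensor is forwarded as `nullptr`, existing callers of `set_params_fprop`/`set_params_dgrad` keep their behavior unchanged; only callers that opt in pay for the extra per-batch indirection.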