@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
                       float p_dropout,
@@ -72,6 +73,7 @@ void set_params_fprop(Flash_fwd_params &params,

     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_k = static_cast<int *>(seqused_k);

     // P = softmax(QK^T)
     params.p_ptr = p_d;
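
The pointer stored here changes how the effective key length of a batch element is resolved further down the stack: when seqused_k is non-null it takes precedence over the extent implied by the cu_seqlens_k offsets. The snippet below is only a minimal sketch of that fallback rule; the names ParamsSketch and effective_seqlen_k are made up for illustration and are not code from this patch or the surrounding kernels.

// Minimal sketch only (not code from this patch): illustrates the intended
// fallback semantics of params.seqused_k relative to params.cu_seqlens_k.
#include <cstdio>

struct ParamsSketch {            // hypothetical stand-in for Flash_fwd_params
    const int *cu_seqlens_k;     // b+1 prefix-sum offsets into the packed K/V buffer
    const int *seqused_k;        // b entries, or nullptr when the feature is unused
};

static int effective_seqlen_k(const ParamsSketch &p, int b) {
    // If seqused_k is provided, only that many keys of batch element b are used;
    // otherwise the full extent given by the cu_seqlens_k offsets applies.
    return p.seqused_k != nullptr ? p.seqused_k[b]
                                  : p.cu_seqlens_k[b + 1] - p.cu_seqlens_k[b];
}

int main() {
    const int cu_seqlens_k[] = {0, 8, 24};  // two sequences occupying 8 and 16 key slots
    const int seqused_k[]    = {5, 11};     // but only 5 and 11 of those keys are valid
    ParamsSketch p{cu_seqlens_k, seqused_k};
    std::printf("%d %d\n", effective_seqlen_k(p, 0), effective_seqlen_k(p, 1));  // prints: 5 11
    return 0;
}
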
@@ -156,6 +158,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
                      nullptr,
+                     nullptr,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
@@ -363,6 +366,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -436,6 +440,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                const int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,
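
To make the new argument concrete, here is a hedged call-site sketch using the PyTorch C++ API. It only shows how cu_seqlens_k and seqused_k relate for a batch of two sequences whose packed key buffer reserves more rows than are actually attended to; the tensor values are invented for illustration, and the remaining arguments of mha_varlen_fwd (q, k, v, ...) are omitted.

// Hedged call-site sketch; values and variable names are illustrative only.
#include <torch/torch.h>

int main() {
    auto opts = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);

    // Packed K/V layout: sequence 0 occupies rows [0, 8), sequence 1 rows [8, 24).
    auto cu_seqlens_k = torch::tensor({0, 8, 24}, opts);  // shape (b+1,)

    // Only the first 5 and 11 key rows of each sequence are actually used;
    // the remaining rows are reserved capacity that the kernel should skip.
    auto seqused_k = torch::tensor({5, 11}, opts);        // shape (b,)

    // cu_seqlens_k and seqused_k would be passed as the corresponding arguments
    // of mha_varlen_fwd; everything else about the call is unchanged.
    return 0;
}
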
@@ -494,6 +499,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }

     at::Tensor q_padded, k_padded, v_padded;
     if (head_size_og % 8 != 0) {
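
The checks above pin down what a caller has to supply: an int32, CUDA, contiguous tensor with one entry per batch element. A caller holding the lengths in another dtype or layout can normalize before the call; the helper below is a hedged sketch of that (normalize_seqused_k is not part of this patch).

// Hedged helper sketch mirroring the checks added above; not part of this patch.
#include <torch/torch.h>

// Returns a tensor satisfying the constraints enforced by mha_varlen_fwd:
// dtype int32, on the same CUDA device as device_ref, contiguous, shape (batch_size).
at::Tensor normalize_seqused_k(const at::Tensor &seqused_k,
                               const at::Tensor &device_ref,
                               int64_t batch_size) {
    TORCH_CHECK(seqused_k.numel() == batch_size,
                "seqused_k must have one entry per batch element");
    return seqused_k.to(device_ref.device(), torch::kInt32).contiguous();
}
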
@@ -554,6 +566,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -1167,6 +1180,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      q_padded, kcache_padded, vcache_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      /*p_ptr=*/nullptr,
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,