
Allow varlen_fwd to take optional seqused_k (#647)

Co-authored-by: bottler <bottler@users.noreply.github.com>
Jeremy Reizenstein, 1 year ago
Parent commit: ce3e7280f8

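In words: mha_varlen_fwd (exposed to Python as varlen_fwd) gains an optional per-batch int32 tensor seqused_k. When it is passed, only the first seqused_k[b] keys of batch element b's packed K/V segment take part in attention; when it is absent, behaviour is unchanged and the full cu_seqlens_k[b+1] - cu_seqlens_k[b] keys are used. A minimal sketch of those semantics in plain PyTorch (illustrative names and values, not code from this commit):

    import torch

    def effective_key_lengths(cu_seqlens_k, seqused_k=None):
        # Number of keys attended per batch element.
        # Without seqused_k: full packed segment lengths (b entries from the b+1 offsets).
        # With seqused_k: the caller-supplied override.
        full = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
        return full if seqused_k is None else seqused_k

    cu_seqlens_k = torch.tensor([0, 128, 192, 292], dtype=torch.int32)
    print(effective_key_lengths(cu_seqlens_k))            # tensor([128,  64, 100], dtype=torch.int32)
    print(effective_key_lengths(cu_seqlens_k,
                                torch.tensor([100, 64, 10], dtype=torch.int32)))
                                                          # tensor([100, 64, 10], dtype=torch.int32)
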
+ 14 - 0
csrc/flash_attn/flash_api.cpp

@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
                       float p_dropout,
@@ -72,6 +73,7 @@ void set_params_fprop(Flash_fwd_params &params,
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_k = static_cast<int *>(seqused_k);
 
     // P = softmax(QK^T)
     params.p_ptr = p_d;
@@ -156,6 +158,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
                      nullptr,
+                     nullptr,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
@@ -363,6 +366,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -436,6 +440,7 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                const int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,
@@ -494,6 +499,13 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }
 
     at::Tensor q_padded, k_padded, v_padded;
     if (head_size_og % 8 != 0) {
@@ -554,6 +566,7 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -1167,6 +1180,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                      q_padded, kcache_padded, vcache_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      /*p_ptr=*/nullptr,
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,

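The checks added to mha_varlen_fwd spell out the contract a caller's seqused_k tensor has to meet before its data pointer is forwarded to set_params_fprop. A small sketch of a tensor that satisfies them (the values below are made up for illustration; requires a CUDA device):

    import torch

    batch_size = 4
    # One int32 entry per batch element: how many of that segment's keys to use.
    seqused_k = torch.tensor([128, 64, 100, 32], dtype=torch.int32, device="cuda")

    assert seqused_k.dtype == torch.int32    # "seqused_k must have dtype int32"
    assert seqused_k.is_cuda                 # "seqused_k must be on CUDA device"
    assert seqused_k.is_contiguous()         # "seqused_k must be contiguous"
    assert seqused_k.shape == (batch_size,)  # CHECK_SHAPE(seqused_k_, batch_size)
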
+ 1 - 1
csrc/flash_attn/src/block_info.h

@@ -19,7 +19,7 @@ struct BlockInfo {
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
         , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }
 

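In BlockInfo, seqused_k now takes precedence over every other way of deriving the key length: the cached length and any appended K_new rows are only consulted when it is null. A Python transliteration of the updated expression, meant as a readability aid rather than kernel code:

    from types import SimpleNamespace

    def actual_seqlen_k(params, bidb, seqlen_k_cache):
        # Mirrors the updated initializer: seqused_k, when provided, overrides
        # everything; otherwise fall back to the cached length plus any
        # appended K_new rows (kv-cache path).
        if params.seqused_k is not None:
            return params.seqused_k[bidb]
        return seqlen_k_cache + (0 if params.knew_ptr is None else params.seqlen_knew)

    params = SimpleNamespace(seqused_k=[100, 64], knew_ptr=None, seqlen_knew=0)
    print(actual_seqlen_k(params, bidb=0, seqlen_k_cache=128))  # 100: seqused_k wins
    params.seqused_k = None
    print(actual_seqlen_k(params, bidb=0, seqlen_k_cache=128))  # 128: fall back to cache length
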
+ 3 - 0
csrc/flash_attn/src/flash.h

@@ -77,6 +77,9 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
 
+    // If provided, the actual length of each k sequence.
+    int * __restrict__ seqused_k;
+
     int *__restrict__ blockmask;
 
     // The K_new and V_new matrices.

+ 1 - 0
flash_attn/flash_attn_interface.py

@@ -83,6 +83,7 @@ def _flash_attn_varlen_forward(
         None,
         cu_seqlens_q,
         cu_seqlens_k,
+        None,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
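
The Python wrapper keeps passing None in the new positional slot, so existing callers see no change; the argument only matters once a seqused_k tensor is actually threaded through. For reference, the semantics the kernel is expected to honour can be written out in plain PyTorch. This is a naive checking sketch under simplifying assumptions (no dropout, no causal mask, no MQA/GQA), not the library's implementation:

    import math
    import torch

    def varlen_attn_reference(q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_k=None):
        # q: (total_q, nheads, d); k, v: (total_k, nheads, d), packed varlen layout.
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
        out = torch.empty_like(q)
        batch_size = cu_seqlens_q.numel() - 1
        for b in range(batch_size):
            qs, qe = cu_seqlens_q[b].item(), cu_seqlens_q[b + 1].item()
            ks, ke = cu_seqlens_k[b].item(), cu_seqlens_k[b + 1].item()
            if seqused_k is not None:
                # Only the first seqused_k[b] keys of this segment participate.
                ke = ks + int(seqused_k[b])
            qb = q[qs:qe].transpose(0, 1)   # (nheads, sq, d)
            kb = k[ks:ke].transpose(0, 1)   # (nheads, sk, d)
            vb = v[ks:ke].transpose(0, 1)   # (nheads, sk, d)
            scores = torch.matmul(qb, kb.transpose(-1, -2)) * softmax_scale
            probs = scores.softmax(dim=-1)
            out[qs:qe] = torch.matmul(probs, vb).transpose(0, 1)
        return out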