فهرست منبع

Add split-kv and M<->H swap to varlen forward decoding attention (#754)

* Add split-k, M<->H to varseq path

* skip M<->H when dropout>0, fix LSE
Grigory Sizov 1 سال پیش
والد
کامیت
af01244ddd
1فایلهای تغییر یافته به همراه80 افزوده شده و 43 حذف شده
  1. 80 43
      csrc/flash_attn/flash_api.cpp

+ 80 - 43
csrc/flash_attn/flash_api.cpp

@@ -42,7 +42,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       float p_dropout,
                       float softmax_scale,
                       int window_size_left,
-                      int window_size_right) {
+                      int window_size_right,
+                      bool seqlenq_ngroups_swapped=false) {
 
     // Reset the parameters
     memset(&params, 0, sizeof(params));
@@ -69,6 +70,10 @@ void set_params_fprop(Flash_fwd_params &params,
         params.k_batch_stride = k.stride(0);
         params.v_batch_stride = v.stride(0);
         params.o_batch_stride = out.stride(0);
+        if (seqlenq_ngroups_swapped) {
+             params.q_batch_stride *= seqlen_q;
+             params.o_batch_stride *= seqlen_q;
+        }
     }
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
@@ -251,6 +256,31 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
     return 1;
 }
 
+void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
+    const int num_heads, const int head_size,  const int max_seqlen_k, const int max_seqlen_q,
+    const int head_size_rounded,  float p_dropout, const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
+
+    // This needs to match with run_mha_fwd_splitkv_dispatch
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
+    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+    // In any case we don't expect seqlen_q to be larger than 64 for inference.
+    const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
+    params.num_splits = num_splits;
+    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
+        if (num_splits < 1) {
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+        }
+        if (params.num_splits > 1) {
+            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+            params.oaccum_ptr = out_accum.data_ptr();
+        }
+        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+    }
+}
+
 std::vector<at::Tensor>
 mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
@@ -382,23 +412,10 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      window_size_left,
                      window_size_right);
 
-    // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
-    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
-    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
-    // In any case we don't expect seqlen_q to be larger than 64 for inference.
-    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
-    params.num_splits = 1;
-    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
-        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
-        if (params.num_splits > 1) {
-            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
-            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
-            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
-            params.oaccum_ptr = out_accum.data_ptr();
-        }
-        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
-    }
+
+    set_params_splitkv(params, batch_size, num_heads,
+                       head_size, seqlen_k, seqlen_q,
+                       head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
 
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -454,7 +471,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
 }
 
 std::vector<at::Tensor>
-mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -462,18 +479,17 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
-               const int max_seqlen_q,
+               int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,
                const float softmax_scale,
                const bool zero_tensors,
-               const bool is_causal,
+               bool is_causal,
                int window_size_left,
                int window_size_right,
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {
 
-    if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -505,12 +521,30 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
 
     const auto sizes = q.sizes();
 
-    const int total_q = sizes[0];
     const int batch_size = cu_seqlens_q.numel() - 1;
-    const int num_heads = sizes[1];
+    int num_heads = sizes[1];
     const int head_size_og = sizes[2];
     const int total_k = k.size(0);
     const int num_heads_k = k.size(1);
+
+    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (is_causal) { window_size_right = 0; }
+
+    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
+
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+        max_seqlen_q = ngroups;
+        num_heads = num_heads_k;
+        cu_seqlens_q_d = nullptr;
+    }
+
+    const int total_q = q.sizes()[0];
+
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -588,7 +622,7 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
                      num_heads, num_heads_k,
                      head_size, head_size_rounded,
                      q_padded, k_padded, v_padded, out,
-                     cu_seqlens_q.data_ptr(),
+                     cu_seqlens_q_d,
                      cu_seqlens_k.data_ptr(),
                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
@@ -596,7 +630,14 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     seqlenq_ngroups_swapped);
+    if (seqlenq_ngroups_swapped) {
+        // Only apply split-k for decoding
+        set_params_splitkv(params, batch_size, num_heads,
+                       head_size, max_seqlen_k, max_seqlen_q,
+                       head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
+    }
 
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -642,6 +683,15 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
         if (out_.has_value()) { out_.value().copy_(out); }
     }
 
+    if (seqlenq_ngroups_swapped) {
+        long size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
+        long size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
+        out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
+        out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+        q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
+    }
+
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 
@@ -1367,23 +1417,10 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
         TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
         params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
     }
-    // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
-    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
-    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
-    // In any case we don't expect seqlen_q to be larger than 64 for inference.
-    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
-    params.num_splits = num_splits;
-    if (num_splits < 1) {
-        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
-    }
-    TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
-    if (params.num_splits > 1) {
-        at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
-        at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
-        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
-        params.oaccum_ptr = out_accum.data_ptr();
-    }
+
+    set_params_splitkv(params, batch_size, num_heads,
+                       head_size, seqlen_k, seqlen_q,
+                       head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
 
     if (alibi_slopes_.has_value()) {
         auto alibi_slopes = alibi_slopes_.value();