Răsfoiți Sursa

Transpose out when swapping seqlen_q and num_groups

Tri Dao 11 luni în urmă
părinte
comite
9eb3d099c1
1 a modificat fișierele cu 12 adăugiri și 4 ștergeri
  1. 12 4
      csrc/flash_attn/flash_api.cpp

+ 12 - 4
csrc/flash_attn/flash_api.cpp

@@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     params.num_splits = num_splits;
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
         if (num_splits < 1) {
-            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+            // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+            params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
         }
         if (params.num_splits > 1) {
             at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
@@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
         seqlen_q = ngroups;
         num_heads = num_heads_k;
@@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
-        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);
@@ -571,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
     const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int ngroups = num_heads / num_heads_k;
     if (seqlenq_ngroups_swapped) {
-        const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
         max_seqlen_q = ngroups;
         num_heads = num_heads_k;
@@ -627,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, total_q, num_heads, head_size_og);
+        CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);