Driss Guessous 8 months ago
Parent
Commit
1c275eb070
1 changed file with 25 additions and 14 deletions

csrc/flash_attn/flash_api.cpp (+25 −14)

@@ -291,7 +291,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
     return 1;
 }
 
-void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
+std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
     const int head_size_rounded, const float p_dropout,
     const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
@@ -303,19 +303,24 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
     const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
     params.num_splits = num_splits;
+    at::Tensor softmax_lse_accum;
+    at::Tensor out_accum;
+
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
         if (num_splits < 1) {
             // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
             params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
         }
         if (params.num_splits > 1) {
-            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
-            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+            out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
             params.oaccum_ptr = out_accum.data_ptr();
         }
         TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     }
+
+    return std::make_tuple(softmax_lse_accum, out_accum);
 }
 
 void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
@@ -476,10 +481,11 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      softcap
                      );
 
-
-    set_params_splitkv(params, batch_size, num_heads,
-                       head_size, seqlen_k, seqlen_q,
-                       head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
+    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+        head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
 
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -725,11 +731,14 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         params.v_batch_stride = v_padded.stride(0);
     }
     params.page_block_size = page_block_size;
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
-        set_params_splitkv(params, batch_size, num_heads,
-                           head_size, max_seqlen_k, max_seqlen_q,
-                           head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
+        std::tie(softmax_lse_accum, out_accum) =
+            set_params_splitkv(params, batch_size, num_heads, head_size,
+                               max_seqlen_k, max_seqlen_q, head_size_rounded,
+                               p_dropout, /*num_splits*/ 0, dprops, opts);
     }
 
     if (leftpad_k_.has_value()) {
@@ -1524,10 +1533,12 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
         TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
         params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
     }
-
-    set_params_splitkv(params, batch_size, num_heads,
-                       head_size, seqlen_k, seqlen_q,
-                       head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
+    
+    // Keep references to these tensors to extend their lifetime
+    at::Tensor softmax_lse_accum, out_accum;
+    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+        head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
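
The change above fixes a tensor-lifetime bug: softmax_lse_accum and out_accum used to be locals inside set_params_splitkv, yet their raw data_ptr() values were stored in params and read later by the split-KV kernel, after the tensors had already been destroyed. Returning the tensors as a tuple hands ownership back to each caller, which then holds them alive across the kernel launch (hence the "Keep references to these tensors to extend their lifetime" comments at the call sites). Below is a minimal sketch of the same pattern, using std::vector<float> as a hypothetical stand-in for at::Tensor and a Params struct standing in for Flash_fwd_params; the names are illustrative only, not from the real API.

#include <cstddef>
#include <tuple>
#include <vector>

// Hypothetical stand-in for Flash_fwd_params: it records a raw pointer
// into storage it does not own.
struct Params { float* accum_ptr = nullptr; };

// Before the fix: the accumulator is a function-local, so its storage is
// freed when the function returns and params.accum_ptr is left dangling.
void set_params_buggy(Params& params, std::size_t n) {
    std::vector<float> accum(n);
    params.accum_ptr = accum.data();
}  // accum is destroyed here; the pointer stored in params now dangles

// After the fix: ownership is returned to the caller, which keeps the
// storage alive for as long as params.accum_ptr is dereferenced.
std::tuple<std::vector<float>> set_params_fixed(Params& params, std::size_t n) {
    std::vector<float> accum(n);
    params.accum_ptr = accum.data();
    // Moving a std::vector transfers its heap buffer, so the pointer
    // recorded in params stays valid after the move.
    return std::make_tuple(std::move(accum));
}

int main() {
    Params params;
    std::vector<float> accum;  // declared at the call site, outliving the call
    std::tie(accum) = set_params_fixed(params, 1024);
    // params.accum_ptr is still valid here, mirroring how the commit keeps
    // softmax_lse_accum and out_accum alive across the split-KV kernel launch.
    return params.accum_ptr == accum.data() ? 0 : 1;
}

One stylistic note: the commit pairs std::make_tuple with std::tie rather than C++17 structured bindings. With structured bindings the call sites could read auto [softmax_lse_accum, out_accum] = set_params_splitkv(...); the std::tie form presumably keeps the code buildable under pre-C++17 standards and lets the tensors be declared before the conditional call in mha_varlen_fwd.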