Prechádzať zdrojové kódy

Remove configure in bwd kernel launch

Tri Dao 1 rok pred
rodič
commit
ea8a25ca38

+ 8 - 24
csrc/flash_attn/flash_api.cpp

@@ -1,5 +1,5 @@
 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/
 
 // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
@@ -204,7 +204,7 @@ void set_params_dgrad(Flash_bwd_params &params,
 
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
-        FWD_HEADDIM_SWITCH(params.d, [&] {
+        HEADDIM_SWITCH(params.d, [&] {
             if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
             } else {
@@ -695,25 +695,11 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
 
-void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     FP16_SWITCH(!params.is_bf16, [&] {
-        if (params.d <= 32) {
-            run_mha_bwd_<elem_type, 32>(params, stream, configure);
-        } else if (params.d <= 64) {
-            run_mha_bwd_<elem_type, 64>(params, stream, configure);
-        } else if (params.d <= 96) {
-            run_mha_bwd_<elem_type, 96>(params, stream, configure);
-        } else if (params.d <= 128) {
-            run_mha_bwd_<elem_type, 128>(params, stream, configure);
-        } else if (params.d <= 160) {
-            run_mha_bwd_<elem_type, 160>(params, stream, configure);
-        } else if (params.d <= 192) {
-            run_mha_bwd_<elem_type, 192>(params, stream, configure);
-        } else if (params.d <= 224) {
-          run_mha_bwd_<elem_type, 224>(params, stream, configure);
-        } else if (params.d <= 256) {
-          run_mha_bwd_<elem_type, 256>(params, stream, configure);
-        }
+        HEADDIM_SWITCH(params.d, [&] {
+            run_mha_bwd_<elem_type, kHeadDim>(params, stream);
+        });
     });
 }
 
@@ -898,7 +884,6 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
-    // launch(params, stream, /*configure=*/true);
 
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -930,7 +915,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     }
 
     if (seqlen_q > 0) {
-        launch(params, stream, /*configure=*/false);
+        launch(params, stream);
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
         dk_expanded.zero_();
@@ -1154,7 +1139,6 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
-    // launch(params, stream, /*configure=*/true);
 
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -1186,7 +1170,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     }
 
     if (max_seqlen_q > 0) {
-        launch(params, stream, /*configure=*/false);
+        launch(params, stream);
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
         dk_expanded.zero_();

+ 1 - 1
csrc/flash_attn/src/flash.h

@@ -182,4 +182,4 @@ struct Flash_bwd_params : public Flash_fwd_params {
 template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim192<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim224<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim256<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim32<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim64<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
 }

+ 2 - 2
csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim96<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
 }

+ 54 - 55
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -43,7 +43,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
 }
 
 template<typename Kernel_traits, bool Is_dropout>
-void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);
     const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
@@ -99,13 +99,12 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
 }
 
 template<typename Kernel_traits, bool Is_dropout>
-void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    if (configure) return;
-    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
+void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
 }
 
 template<typename T>
-void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
@@ -118,18 +117,18 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
             if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
             } else {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
             }
         } else {  // 96 KB
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
         }
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
@@ -142,39 +141,39 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // Changing AtomLayoutMdQ from 2 to 4 takes the same time
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
         // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
         if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
             // This has a lot of register spilling
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
         } else {
             // if (params.h == params.h_k) {
-                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
-                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
-                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
+                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
+                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
             // } else {
             // }
         }
     });
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
     // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
 
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
 }
 
 template<typename T>
-void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
@@ -188,19 +187,19 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr(!Is_dropout) {  // 92KB
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
             } else {  // 116 KB
                 // This is faster for dropout since we don't have many registers to spare
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
             }
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
         }
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
@@ -212,29 +211,29 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
         // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
         // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
         if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
         } else {
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
         }
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
 
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     int device;
     cudaGetDevice(&device);
@@ -246,15 +245,15 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
         }
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
@@ -266,23 +265,23 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 136 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
         }
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
     });
 }
 
 template<typename T>
-void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
@@ -294,9 +293,9 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) {  // H100
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
         } else {  // A100, we don't do double buffering to save smem
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
         }
     });
 }

+ 2 - 2
csrc/flash_attn/src/generate_kernels.py

@@ -32,8 +32,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params
 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
-    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
+void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
+    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
 }}
 """
 

+ 1 - 1
csrc/flash_attn/src/static_switch.h

@@ -36,7 +36,7 @@
     }                                        \
   }()
 
-#define FWD_HEADDIM_SWITCH(HEADDIM, ...)   \
+#define HEADDIM_SWITCH(HEADDIM, ...)   \
   [&] {                                    \
     if (HEADDIM <= 32) {                   \
       constexpr static int kHeadDim = 32;  \