浏览代码

Split bwd into more .cu files to speed up compilation

Tri Dao 8 月之前
父节点
当前提交
65f723bb9a
共有 89 个文件被更改,包括 304 次插入153 次删除
  1. 3 1
      csrc/flash_attn/flash_api.cpp
  2. 1 1
      csrc/flash_attn/src/flash.h
  3. 10 0
      csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu
  4. 3 3
      csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
  5. 10 0
      csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu
  6. 3 3
      csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
  7. 10 0
      csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu
  8. 3 3
      csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
  9. 10 0
      csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu
  10. 3 3
      csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
  11. 10 0
      csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu
  12. 3 3
      csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
  13. 10 0
      csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu
  14. 3 3
      csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
  15. 10 0
      csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu
  16. 3 3
      csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
  17. 10 0
      csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu
  18. 3 3
      csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
  19. 10 0
      csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu
  20. 3 3
      csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
  21. 10 0
      csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu
  22. 3 3
      csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
  23. 10 0
      csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu
  24. 3 3
      csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
  25. 10 0
      csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu
  26. 3 3
      csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
  27. 10 0
      csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu
  28. 3 3
      csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
  29. 10 0
      csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu
  30. 3 3
      csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
  31. 43 45
      csrc/flash_attn/src/flash_bwd_launch_template.h
  32. 1 1
      csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu
  33. 1 1
      csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
  34. 1 1
      csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu
  35. 1 1
      csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
  36. 1 1
      csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu
  37. 1 1
      csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
  38. 1 1
      csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu
  39. 1 1
      csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
  40. 1 1
      csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu
  41. 1 1
      csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
  42. 1 1
      csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu
  43. 1 1
      csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
  44. 1 1
      csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu
  45. 1 1
      csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
  46. 1 1
      csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu
  47. 1 1
      csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
  48. 1 1
      csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu
  49. 1 1
      csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
  50. 1 1
      csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu
  51. 1 1
      csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
  52. 1 1
      csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu
  53. 1 1
      csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
  54. 1 1
      csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu
  55. 1 1
      csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
  56. 1 1
      csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu
  57. 1 1
      csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
  58. 1 1
      csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu
  59. 1 1
      csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
  60. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu
  61. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu
  62. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu
  63. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
  64. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu
  65. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
  66. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu
  67. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
  68. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu
  69. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
  70. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu
  71. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
  72. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu
  73. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
  74. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu
  75. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
  76. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu
  77. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
  78. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu
  79. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
  80. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu
  81. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
  82. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu
  83. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
  84. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu
  85. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
  86. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu
  87. 1 1
      csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
  88. 5 8
      csrc/flash_attn/src/generate_kernels.py
  89. 14 0
      setup.py

+ 3 - 1
csrc/flash_attn/flash_api.cpp

@@ -800,7 +800,9 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
 void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
-            run_mha_bwd_<elem_type, kHeadDim>(params, stream);
+            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+            });
         });
     });
 }

+ 1 - 1
csrc/flash_attn/src/flash.h

@@ -192,4 +192,4 @@ struct Flash_bwd_params : public Flash_fwd_params {
 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::half_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::half_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::half_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::half_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::half_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::half_t, true>(params, stream);
+}

+ 3 - 3
csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu

@@ -1,10 +1,10 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::half_t, false>(params, stream);
 }

+ 43 - 45
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -65,7 +65,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
     flash::convert_dKV<Kernel_traits>(params);
 }
 
-template<typename Kernel_traits, bool Is_dropout>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
 void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);
@@ -90,24 +90,22 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                            // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                            // If Is_local, set Is_causal to false
-                            auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
-                            // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
-                            if (smem_size_dq_dk_dv >= 48 * 1024)  {
-                                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                            }
-                            kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                            C10_CUDA_KERNEL_LAUNCH_CHECK();
-                        });
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
+                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                        // If Is_local, set Is_causal to false
+                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
+                        if (smem_size_dq_dk_dv >= 48 * 1024)  {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                        }
+                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                     });
                 });
             });
@@ -123,14 +121,14 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-template<typename Kernel_traits, bool Is_dropout>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 #ifndef FLASHATTENTION_DISABLE_BACKWARD
-    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout, Is_causal>(params, stream);
 #endif
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     int device;
@@ -144,17 +142,17 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
             if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
             } else {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
         } else {  // 96 KB
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     int device;
@@ -174,13 +172,13 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
         // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
         if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             // This has a lot of register spilling
             // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
         } else {
             // if (params.h == params.h_k) {
                 // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                 // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
                 // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
             // } else {
@@ -199,7 +197,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     int device;
@@ -214,18 +212,18 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr(!Is_dropout) {  // 92KB
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
             } else {  // 116 KB
                 // This is faster for dropout since we don't have many registers to spare
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     int device;
@@ -243,7 +241,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
         // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
         if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
             // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
             // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
             // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
@@ -251,7 +249,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
             // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
         } else {
             // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout, Is_causal>(params, stream);
         }
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
 
@@ -259,7 +257,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     int device;
@@ -272,14 +270,14 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     }
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     int device;
@@ -292,14 +290,14 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     }
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 136 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
@@ -312,12 +310,12 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     }
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) {  // H100
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else if (max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout, Is_causal>(params, stream);
         } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
             if constexpr (!Is_dropout) {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false, Is_causal>(params, stream);
             }
         }
     });

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu

@@ -1,4 +1,4 @@
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 

+ 5 - 8
csrc/flash_attn/src/generate_kernels.py

@@ -33,8 +33,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Fla
 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
 
 template<>
-void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
-    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
+void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params &params, cudaStream_t stream) {{
+    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
 """
 
@@ -55,7 +55,7 @@ class Kernel:
             )
         elif self.direction == "bwd":
             return KERNEL_IMPL_TEMPLATE_BWD.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
         else:
             return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
@@ -68,16 +68,13 @@ class Kernel:
 
 
 def get_all_kernels() -> List[Kernel]:
-    for direction in ["fwd", "fwd_split"]:
+    for direction in ["fwd", "fwd_split", "bwd"]:
         for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
-    for direction in ["bwd"]:
-        for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
-            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction)
 
 
 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
-    prelude = """// Copyright (c) 2023, Tri Dao.
+    prelude = """// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"\n
 """

+ 14 - 0
setup.py

@@ -222,6 +222,20 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                 "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",