
Split into more .cu files to speed up compilation

Tri Dao committed 8 months ago (commit 908511b2b6)
69 files changed, 503 additions and 210 deletions
  1. csrc/flash_attn/flash_api.cpp (+7 -5)
  2. csrc/flash_attn/src/flash.h (+2 -2)
  3. csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu (+10 -0)
  4. csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu (+2 -2)
  5. csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu (+10 -0)
  6. csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu (+2 -2)
  7. csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu (+10 -0)
  8. csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu (+2 -2)
  9. csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu (+10 -0)
  10. csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu (+2 -2)
  11. csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu (+10 -0)
  12. csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu (+2 -2)
  13. csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu (+10 -0)
  14. csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu (+2 -2)
  15. csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu (+10 -0)
  16. csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu (+2 -2)
  17. csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu (+10 -0)
  18. csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu (+2 -2)
  19. csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu (+10 -0)
  20. csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu (+2 -2)
  21. csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu (+10 -0)
  22. csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu (+2 -2)
  23. csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu (+10 -0)
  24. csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu (+2 -2)
  25. csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu (+10 -0)
  26. csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu (+2 -2)
  27. csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu (+10 -0)
  28. csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu (+2 -2)
  29. csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu (+10 -0)
  30. csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu (+2 -2)
  31. csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu (+10 -0)
  32. csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu (+2 -2)
  33. csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu (+10 -0)
  34. csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu (+2 -2)
  35. csrc/flash_attn/src/flash_fwd_launch_template.h (+128 -146)
  36. csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu (+7 -0)
  37. csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu (+1 -1)
  38. csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu (+7 -0)
  39. csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu (+1 -1)
  40. csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu (+7 -0)
  41. csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu (+1 -1)
  42. csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu (+7 -0)
  43. csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu (+1 -1)
  44. csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu (+7 -0)
  45. csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu (+1 -1)
  46. csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu (+7 -0)
  47. csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu (+1 -1)
  48. csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu (+7 -0)
  49. csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu (+1 -1)
  50. csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu (+7 -0)
  51. csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu (+1 -1)
  52. csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu (+7 -0)
  53. csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu (+1 -1)
  54. csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu (+7 -0)
  55. csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu (+1 -1)
  56. csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu (+7 -0)
  57. csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu (+1 -1)
  58. csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu (+7 -0)
  59. csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu (+1 -1)
  60. csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu (+7 -0)
  61. csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu (+1 -1)
  62. csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu (+7 -0)
  63. csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu (+1 -1)
  64. csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu (+7 -0)
  65. csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu (+1 -1)
  66. csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu (+7 -0)
  67. csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu (+1 -1)
  68. csrc/flash_attn/src/generate_kernels.py (+14 -9)
  69. setup.py (+32 -0)

+ 7 - 5
csrc/flash_attn/flash_api.cpp

@@ -238,11 +238,13 @@ void set_params_dgrad(Flash_bwd_params &params,
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
-            if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
-                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-            } else {
-                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
-            }
+            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
+                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+                } else {
+                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
+                }
+            });
         });
     });
 }
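
The key API change here: the runtime branch on params.is_causal is hoisted into a compile-time BOOL_SWITCH, so run_mha_fwd_ and run_mha_fwd_splitkv_dispatch take Is_causal as a third template argument and each causal/non-causal variant can be compiled in its own translation unit. As a minimal sketch only (assumption: the real macro is defined in csrc/flash_attn/src/static_switch.h, which is not part of this diff), a BOOL_SWITCH-style macro turns a runtime bool into a constexpr value visible to the lambda it invokes:

    #include <cstdio>

    // Sketch of a BOOL_SWITCH-style macro: maps a runtime bool to a
    // compile-time constant and runs the given lambda with it in scope.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)          \
      [&] {                                             \
        if (COND) {                                     \
          constexpr static bool CONST_NAME = true;      \
          return __VA_ARGS__();                         \
        } else {                                        \
          constexpr static bool CONST_NAME = false;     \
          return __VA_ARGS__();                         \
        }                                               \
      }()

    // Stand-in for run_mha_fwd_<elem_type, kHeadDim, Is_causal>.
    template <bool Is_causal>
    void run_kernel_stub() { std::printf("Is_causal = %d\n", Is_causal); }

    int main() {
        bool is_causal = true;  // e.g. params.is_causal, only known at runtime
        BOOL_SWITCH(is_causal, Is_causal, [&] { run_kernel_stub<Is_causal>(); });
    }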

+ 2 - 2
csrc/flash_attn/src/flash.h

@@ -188,7 +188,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
 template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
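
Both forward entry points are now declared with Is_causal as a template parameter, and flash.h only declares them; each auto-generated .cu below supplies exactly one (dtype, head dimension, causal) combination. Two standard C++ mechanisms are used for this, shown in miniature with stand-in names (an illustrative sketch, not the repository's code): the flash_fwd_hdim*_sm80.cu files define an explicit specialization of run_mha_fwd_, while the flash_fwd_split_hdim*_sm80.cu files request an explicit instantiation of run_mha_fwd_splitkv_dispatch, whose primary definition lives in flash_fwd_launch_template.h.

    #include <cstdio>

    // Primary template with its definition visible in this TU
    // (stands in for run_mha_fwd_splitkv_dispatch from flash_fwd_launch_template.h).
    template <int Headdim, bool Is_causal>
    void splitkv_dispatch_stub() { std::printf("splitkv hdim=%d causal=%d\n", Headdim, Is_causal); }

    // Explicit instantiation (the flash_fwd_split_hdim*.cu pattern):
    // no body, just a request to emit this combination's code in this TU.
    template void splitkv_dispatch_stub<128, true>();

    // Primary template that is only declared (stands in for run_mha_fwd_ in flash.h).
    template <int Headdim, bool Is_causal>
    void fwd_stub();

    // Explicit specialization (the flash_fwd_hdim*.cu pattern):
    // supplies the full definition for exactly this combination.
    template <>
    void fwd_stub<128, true>() { splitkv_dispatch_stub<128, true>(); }

    int main() { fwd_stub<128, true>(); }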

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::bfloat16_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::half_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }

+ 10 - 0
csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
+}

+ 2 - 2
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
 }

+ 128 - 146
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -95,7 +95,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename Kernel_traits>
+template<typename Kernel_traits, bool Is_causal>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
     static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
@@ -104,27 +104,25 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
-                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                                    // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                                    // If Is_local, set Is_causal to false
-                                    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
-                                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
-                                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                                    if (smem_size >= 48 * 1024) {
-                                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                                    }
-                                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                                    C10_CUDA_KERNEL_LAUNCH_CHECK();
-                                });
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                        ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                                // If Is_local, set Is_causal to false
+                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
+                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                                if (smem_size >= 48 * 1024) {
+                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                                }
+                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                             });
                         });
                     });
@@ -159,161 +157,149 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     }
 }
 
-template<typename T, int Headdim>
+template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
-                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-                // Using block size (64 x 256) is 27% slower for seqlen=2k
-                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-        });
+        if constexpr(!Is_dropout) {
+            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+            // Using block size (64 x 256) is 27% slower for seqlen=2k
+            // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-            if (is_sm8x) {
-                if constexpr(!Is_causal) {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                } else {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                }
-            } else {
+        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+        if (is_sm8x) {
+            if constexpr(!Is_causal) {
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            } else {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-            // These two are always slower
-            // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
-        });
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        // These two are always slower
+        // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
-                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
-                if (is_sm8x) {
-                    if constexpr(!Is_causal) {
-                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                    } else {
-                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                    }
+        if constexpr(!Is_dropout) {
+            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+            // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
+            if (is_sm8x) {
+                if constexpr(!Is_causal) {
+                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                 } else {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                 }
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // 1st ones are good for H100, A100
-                // 2nd one is good for A6000 bc we get slightly better occupancy
             } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
-                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
-        });
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // 1st ones are good for H100, A100
+            // 2nd one is good for A6000 bc we get slightly better occupancy
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        }
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            // For A100, H100, 128 x 32 is the fastest.
-            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-            // and 128 x 64 with 8 warps is the fastest for non-causal.
-            if (is_sm8x) {
-                if constexpr(!Is_causal) {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                } else {
-                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-                }
+        // For A100, H100, 128 x 32 is the fastest.
+        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+        // and 128 x 64 with 8 warps is the fastest for non-causal.
+        if (is_sm8x) {
+            if constexpr(!Is_causal) {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
             } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-        });
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if constexpr(!Is_dropout) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-        });
+        if constexpr(!Is_dropout) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
     int device;
@@ -326,23 +312,21 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
-            // If we have N = 32, there are only 1024 elements to load at once, where each load
-            // is 8 elements. This means we can only use 128 threads and not 256 threads.
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
+        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
+        // If we have N = 32, there are only 1024 elements to load at once, where each load
+        // is 8 elements. This means we can only use 128 threads and not 256 threads.
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
@@ -357,18 +341,16 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-            // For A100, we want to run with 128 x 64 (128KB smem).
-            // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
-            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-            // 64 KB
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            // 96 KB
-            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
+        // For A100, we want to run with 128 x 64 (128KB smem).
+        // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
+        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // 64 KB
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // 96 KB
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
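
Net effect in the launch template: run_flash_splitkv_fwd and every run_mha_fwd_hdim* helper now receive Is_causal as a template parameter instead of switching on params.is_causal internally, so a translation unit that references only the causal (or only the non-causal) entry point no longer instantiates both sets of kernels. A minimal sketch of why hoisting the flag out of the launcher roughly halves the instantiations per file (stand-in names, not the real kernels):

    // Stands in for a heavy flash_fwd kernel instantiation.
    template <int Headdim, bool Is_causal>
    void kernel_stub() {}

    // Before: one launcher per (T, Headdim); compiling it pulls both branches,
    // hence both kernel instantiations, into the same translation unit.
    template <int Headdim>
    void launch_before(bool is_causal) {
        if (is_causal) { kernel_stub<Headdim, true>(); }
        else           { kernel_stub<Headdim, false>(); }
    }

    // After: causality is part of the launcher's signature; a TU that only
    // references launch_after<128, true> never instantiates kernel_stub<128, false>.
    template <int Headdim, bool Is_causal>
    void launch_after() {
        kernel_stub<Headdim, Is_causal>();
    }

    int main() {
        launch_before<128>(true);   // instantiates kernel_stub<128, true> and <128, false>
        launch_after<128, true>();  // instantiates only kernel_stub<128, true>
    }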

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 1 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu

@@ -4,4 +4,4 @@
 
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
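
Note: the split-kv hunks above all follow the same pattern. For every (dtype, head dim) pair, the existing translation unit now pins is_causal = false, and a new *_causal_sm80.cu unit pins is_causal = true, so the causal and non-causal specializations compile as separate nvcc jobs. As a rough illustration (not part of the commit), the sketch below enumerates the file names this naming scheme yields for the split-kv forward kernels; the head-dimension and dtype lists are the ones used by generate_kernels.py further down.

    # Illustrative sketch only: mirrors the naming convention visible in the
    # hunks above; it is not part of this diff.
    import itertools

    DTYPES = ["fp16", "bf16"]
    HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
    IS_CAUSAL = ["false", "true"]

    for dtype, hdim, causal in itertools.product(DTYPES, HEAD_DIMENSIONS, IS_CAUSAL):
        tag = "causal_" if causal == "true" else ""
        print(f"flash_fwd_split_hdim{hdim}_{dtype}_{tag}sm80.cu")
    # 32 names in total: 16 existing non-causal units plus 16 new *_causal_* units.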

+ 14 - 9
csrc/flash_attn/src/generate_kernels.py

@@ -16,17 +16,18 @@ DTYPE_MAP = {
 
 SM = [80]  # Sm80 kernels support up to
 HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
+IS_CAUSAL = ["false", "true"]
 KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
-    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
+void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
+    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
 """
 
 KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
 """
 
 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
@@ -43,13 +44,14 @@ class Kernel:
     sm: int
     dtype: str
     head_dim: int
+    is_causal: bool
     direction: str
 
     @property
     def template(self) -> str:
         if self.direction == "fwd":
             return KERNEL_IMPL_TEMPLATE_FWD.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
         elif self.direction == "bwd":
             return KERNEL_IMPL_TEMPLATE_BWD.format(
@@ -57,18 +59,21 @@ class Kernel:
             )
         else:
             return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
 
     @property
     def filename(self) -> str:
-        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
+        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"
 
 
 def get_all_kernels() -> List[Kernel]:
-    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
-        for direction in ["fwd", "bwd", "fwd_split"]:
-            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
+    for direction in ["fwd", "fwd_split"]:
+        for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
+    for direction in ["bwd"]:
+        for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction)
 
 
 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
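
For reference, here is a minimal sketch of what the updated generator emits for a single split-kv configuration. The template string and the filename pattern are condensed from the hunks above; the variable names and the chosen example configuration (bf16, head dim 128, causal) are only for illustration, not a verbatim excerpt of the script.

    # Sketch of one generated file, condensed from the diff above.
    KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"

    template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
    """

    dtype_cpp, head_dim, is_causal = "cutlass::bfloat16_t", 128, "true"
    filename = f"flash_fwd_split_hdim{head_dim}_bf16_{'causal_' if is_causal == 'true' else ''}sm80.cu"
    print(filename)  # flash_fwd_split_hdim128_bf16_causal_sm80.cu
    print(KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
        DTYPE=dtype_cpp, HEAD_DIM=head_dim, IS_CAUSAL=is_causal))
    # Prints the explicit template instantiation that lands in that file.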

+ 32 - 0
setup.py

@@ -151,6 +151,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
@@ -183,6 +199,22 @@ if not SKIP_CUDA_BUILD:
                 "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
             ],
             extra_compile_args={
                 "cxx": ["-O3", "-std=c++17"] + generator_flag,