Pārlūkot izejas kodu

Add a macro for namespace (#1419)

Driss Guessous 2 mēneši atpakaļ
vecāks
revīzija
bc482cbf91
100 mainītis faili ar 705 papildinājumiem un 272 dzēšanām
  1. 2 1
      .gitignore
  2. 8 5
      csrc/flash_attn/flash_api.cpp
  3. 3 2
      csrc/flash_attn/src/alibi.h
  4. 3 2
      csrc/flash_attn/src/block_info.h
  5. 5 4
      csrc/flash_attn/src/dropout.h
  6. 5 0
      csrc/flash_attn/src/flash.h
  7. 5 1
      csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu
  8. 5 1
      csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
  9. 5 1
      csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu
  10. 5 1
      csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
  11. 5 1
      csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu
  12. 5 1
      csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
  13. 5 1
      csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu
  14. 5 1
      csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
  15. 5 1
      csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu
  16. 5 1
      csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
  17. 5 1
      csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu
  18. 5 1
      csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
  19. 5 1
      csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu
  20. 5 1
      csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
  21. 5 1
      csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu
  22. 5 1
      csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
  23. 5 1
      csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu
  24. 5 1
      csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
  25. 5 1
      csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu
  26. 5 1
      csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
  27. 5 1
      csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu
  28. 5 1
      csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
  29. 5 1
      csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu
  30. 5 1
      csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
  31. 5 1
      csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu
  32. 5 1
      csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
  33. 5 1
      csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu
  34. 5 1
      csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
  35. 47 46
      csrc/flash_attn/src/flash_bwd_kernel.h
  36. 11 6
      csrc/flash_attn/src/flash_bwd_launch_template.h
  37. 14 13
      csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
  38. 5 1
      csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu
  39. 5 1
      csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
  40. 5 1
      csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu
  41. 5 1
      csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
  42. 5 1
      csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu
  43. 5 1
      csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
  44. 5 1
      csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu
  45. 5 1
      csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
  46. 5 1
      csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu
  47. 5 1
      csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
  48. 5 1
      csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu
  49. 5 1
      csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
  50. 5 1
      csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu
  51. 5 1
      csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
  52. 5 1
      csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu
  53. 5 1
      csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
  54. 5 1
      csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu
  55. 5 1
      csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
  56. 5 1
      csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu
  57. 5 1
      csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
  58. 5 1
      csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu
  59. 5 1
      csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
  60. 5 1
      csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu
  61. 5 1
      csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
  62. 5 1
      csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu
  63. 5 1
      csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
  64. 5 1
      csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu
  65. 5 1
      csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
  66. 69 68
      csrc/flash_attn/src/flash_fwd_kernel.h
  67. 7 3
      csrc/flash_attn/src/flash_fwd_launch_template.h
  68. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu
  69. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu
  70. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu
  71. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
  72. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu
  73. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
  74. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu
  75. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
  76. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu
  77. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
  78. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu
  79. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
  80. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu
  81. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
  82. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu
  83. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
  84. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu
  85. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
  86. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu
  87. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
  88. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu
  89. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
  90. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu
  91. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
  92. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu
  93. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
  94. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu
  95. 5 1
      csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
  96. 33 31
      csrc/flash_attn/src/generate_kernels.py
  97. 4 3
      csrc/flash_attn/src/mask.h
  98. 67 0
      csrc/flash_attn/src/namespace_config.h
  99. 4 2
      csrc/flash_attn/src/philox.cuh
  100. 3 2
      csrc/flash_attn/src/rotary.h

+ 2 - 1
.gitignore

@@ -22,9 +22,10 @@ var/
 *.egg-info/
 .installed.cfg
 *.egg
+.eggs/
 
 # IDE-related
 .idea/
 
 # Dev
-venv
+venv

+ 8 - 5
csrc/flash_attn/flash_api.cpp

@@ -12,6 +12,7 @@
 
 #include <cutlass/numeric_types.h>
 
+#include "namespace_config.h"
 #include "hardware_info.h"
 #include "flash.h"
 #include "static_switch.h"
@@ -20,6 +21,7 @@
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
+namespace FLASH_NAMESPACE {
 
 void set_params_fprop(Flash_fwd_params &params,
                       // sizes
@@ -1471,12 +1473,13 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     }
     return {out, softmax_lse};
 }
+} // namespace FLASH_NAMESPACE
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "FlashAttention";
-    m.def("fwd", &mha_fwd, "Forward pass");
-    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
-    m.def("bwd", &mha_bwd, "Backward pass");
-    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
-    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
+    m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
+    m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
+    m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
+    m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
+    m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
 }

+ 3 - 2
csrc/flash_attn/src/alibi.h

@@ -1,5 +1,6 @@
 #include <cmath>
 
+#include "namespace_config.h"
 #include <cute/tensor.hpp>
 
 #include <cutlass/cutlass.h>
@@ -7,7 +8,7 @@
 
 #include "utils.h"
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -71,4 +72,4 @@ struct Alibi {
 
 };
 
-}  // namespace flash
+}  // namespace FLASH_NAMESPACE

+ 3 - 2
csrc/flash_attn/src/block_info.h

@@ -4,7 +4,8 @@
 
 #pragma once
 
-namespace flash {
+#include "namespace_config.h"
+namespace FLASH_NAMESPACE {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -45,4 +46,4 @@ struct BlockInfo {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-}  // namespace flash
+}  // namespace FLASH_NAMESPACE

+ 5 - 4
csrc/flash_attn/src/dropout.h

@@ -4,10 +4,11 @@
 
 #pragma once
 
+#include "namespace_config.h"
 #include "philox.cuh"
 #include "utils.h"
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 struct Dropout {
 
@@ -26,7 +27,7 @@ struct Dropout {
     __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                          int block_row_start, int block_col_start, int block_row_stride) {
         // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
-        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
+        Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));
         using T = typename Engine::value_type;
         auto encode_dropout = [](bool keep, T val) {
             return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
@@ -41,7 +42,7 @@ struct Dropout {
             #pragma unroll
             for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
                 // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
-                uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
+                uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
                 // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
                 uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
                 // Special implementation for 16-bit types: we duplicate the threshold to the
@@ -91,4 +92,4 @@ struct Dropout {
 
 };
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE

+ 5 - 0
csrc/flash_attn/src/flash.h

@@ -4,11 +4,14 @@
 
 #pragma once
 
+#include "namespace_config.h"
+
 #include <cuda.h>
 #include <vector>
 
 #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
 
+namespace FLASH_NAMESPACE {
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
 constexpr int D_DIM = 2;
@@ -187,3 +190,5 @@ template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_pa
 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
 template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+
+}  // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim128<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim192<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim192<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim256<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim256<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim32<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim32<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim64<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim64<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim96<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_bwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_bwd_<cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
     run_mha_bwd_hdim96<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 47 - 46
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include "namespace_config.h"
 #include <cute/tensor.hpp>
 
 #include <cutlass/cutlass.h>
@@ -19,7 +20,7 @@
 
 #include "alibi.h"
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -352,10 +353,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         #pragma unroll
         for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
             gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
             gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
         return;
@@ -371,28 +372,28 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
     if (Kernel_traits::Is_V_in_regs) {
         // Clear the smem tiles to account for predicated off loads
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
-        flash::cp_async_fence();
+        FLASH_NAMESPACE::cp_async_fence();
     }
 
     Tensor tdOrdO = make_fragment_like(tdOgdO);
     Tensor tdOrO = make_fragment_like(tdOgO);
     if (!Is_first) {
         // Clear the smem tiles to account for predicated off loads
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
     } else {
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
     }
-    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
         gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
     );
 
@@ -417,15 +418,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }
 
-    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
         gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     if (!Kernel_traits::Is_V_in_regs) {
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
             gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
     }
-    flash::cp_async_fence();
+    FLASH_NAMESPACE::cp_async_fence();
 
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
@@ -442,14 +443,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }
 
-    flash::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
+    FLASH_NAMESPACE::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
                            bidb, bidh, tidx, params.h);
 
     clear(acc_dv);
     clear(acc_dk);
 
     const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
+    FLASH_NAMESPACE::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
 
     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)
@@ -468,21 +469,21 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         //     cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
         // }
         // if (cute::thread0()) { print(tSrK); }
-        flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
+        FLASH_NAMESPACE::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
                     smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
 
         if constexpr (Is_softcap) {
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+        Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout()));
         // if (cute::thread(32, 0)) { print(scores); }
 
         // Softcapping - calculating dTanh and scaling dS later with it
         [[maybe_unused]] Tensor dtanh = make_tensor_like(scores);
         if constexpr (Is_softcap) {
-            flash::calculate_dtanh(scores, dtanh, params.softcap);
+            FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
         }
 
         // Alibi
@@ -500,7 +501,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // So we need to mask out the elements beyond actual_seqlen_k.
         if (!Is_causal && !Is_local) {
             if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
-                flash::apply_mask(scores, binfo.actual_seqlen_k,
+                FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k,
                                   n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
             }
         } else if (Is_causal) {
@@ -510,7 +511,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             // But we still want to mask out elements beyond actual_seqlen_k.
             if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
                 || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
-                flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                FLASH_NAMESPACE::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                                          binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                                          binfo.actual_seqlen_q,
                                          // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
@@ -520,7 +521,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
                 || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
                 || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
-                flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                                         binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
                                         binfo.actual_seqlen_q, AtomLayoutMS * 16,
                                         params.window_size_left, params.window_size_right);
@@ -530,7 +531,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
-        flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
+        FLASH_NAMESPACE::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
         if constexpr (Is_dropout) {
             int warp_id = tidx / 32;
             int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
@@ -543,11 +544,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         }
         // Convert scores from fp32 to fp16/bf16
         Tensor rP = !Is_dropout
-            ? flash::convert_type<Element>(acc_s)
-            : flash::convert_type_relu<Element>(acc_s);
+            ? FLASH_NAMESPACE::convert_type<Element>(acc_s)
+            : FLASH_NAMESPACE::convert_type_relu<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
-        Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
+        Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
         Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP);     // ((Atom,AtomNum), MMA_N, MMA_N)
         cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
         // if (cute::thread0()) { print(tPaP); }
@@ -560,7 +561,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s));                     // MMA
 
         clear(acc_dp);
-        // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), flash::convert_layout_acc_rowcol(acc_dp.layout()));
+        // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout()));
         // #pragma unroll
         // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) {
         //     #pragma unroll
@@ -571,7 +572,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
         // if (cute::thread0()) { print(dP_sum); }
 
-        flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
+        FLASH_NAMESPACE::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
             acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
             smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
         );
@@ -612,13 +613,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             tSsQ.data() = tSsQ.data() + sQ_offset;
             // Advance gQ
             tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
-            flash::cp_async_fence();
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
+            FLASH_NAMESPACE::cp_async_fence();
         }
 
         Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
         // Convert dS from fp32 to fp16
-        Tensor tdSrdS = flash::convert_type<Element>(dS_reshaped);
+        Tensor tdSrdS = FLASH_NAMESPACE::convert_type<Element>(dS_reshaped);
         // if (cute::thread0()) { print(tPrP); }
         Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS);                                          // ((Atom,AtomNum), MMA_N, MMA_N)
         cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
@@ -626,10 +627,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
         // Layout p_l = tPrP.layout();
         // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
-        // flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
+        // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
         // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
-        // flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
-        flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
+        // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
+        FLASH_NAMESPACE::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
                     smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
         // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
         // if (cute::thread0()) { print(acc_dv); }
@@ -641,15 +642,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
             if (Is_first) {
                 tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
-                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
-                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
+                FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
+                FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
             } else {
-                flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
-                flash::cp_async_fence();
+                FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
+                FLASH_NAMESPACE::cp_async_fence();
             }
         }
 
-        flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
+        FLASH_NAMESPACE::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
                     smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
         // if (cute::thread0()) { print(acc_dq); }
 
@@ -678,12 +679,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             #pragma unroll
             for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
             // Convert acc_dq from fp32 to fp16
-            Tensor rdQ = flash::convert_type<Element>(acc_dq);
+            Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
             Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
             cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
         }
 
-        flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
+        FLASH_NAMESPACE::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
                     smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
         // if (cute::thread0()) { print(acc_dk); }
         if (Double_buffer) {  // Double buffer for sQ
@@ -693,8 +694,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             __syncthreads();
             // Advance gQ
             tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
-            flash::cp_async_fence();
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
+            FLASH_NAMESPACE::cp_async_fence();
         }
 
         if (Is_first && m_block > m_block_min) {
@@ -730,8 +731,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }
 
     // Convert acc_dv from fp32 to fp16
-    Tensor rdK = flash::convert_type<Element>(acc_dk);
-    Tensor rdV = flash::convert_type<Element>(acc_dv);
+    Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
+    Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);
 
     Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
     Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
@@ -782,10 +783,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     #pragma unroll
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
-    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
 

+ 11 - 6
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include "namespace_config.h"
 #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
 
 #include "static_switch.h"
@@ -12,6 +13,8 @@
 #include "flash_bwd_preprocess_kernel.h"
 #include "flash_bwd_kernel.h"
 
+namespace FLASH_NAMESPACE {
+
 // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 #define ARCH_SUPPORTS_FLASH
@@ -30,7 +33,7 @@ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
 
 DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
     #if defined(ARCH_SUPPORTS_FLASH)
-       flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
+       FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -39,7 +42,7 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo
 DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
     #if defined(ARCH_SUPPORTS_FLASH)
         static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
-        flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
+        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -48,22 +51,22 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool
 
 template<bool Clear_dQaccum=true, typename Kernel_traits>
 __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
-    flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
+    FLASH_NAMESPACE::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
 }
 
 template<typename Kernel_traits>
 __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
-    flash::clear_dKVaccum<Kernel_traits>(params);
+    FLASH_NAMESPACE::clear_dKVaccum<Kernel_traits>(params);
 }
 
 template<typename Kernel_traits>
 __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
-    flash::convert_dQ<Kernel_traits>(params, nsplits);
+    FLASH_NAMESPACE::convert_dQ<Kernel_traits>(params, nsplits);
 }
 
 template<typename Kernel_traits>
 __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
-    flash::convert_dKV<Kernel_traits>(params);
+    FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);
 }
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -321,3 +324,5 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
         }
     });
 }
+
+} // namespace FLASH_NAMESPACE {

+ 14 - 13
csrc/flash_attn/src/flash_bwd_preprocess_kernel.h

@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include "namespace_config.h"
 #include <cute/tensor.hpp>
 
 #include <cutlass/cutlass.h>
@@ -14,7 +15,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -32,8 +33,8 @@ inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engi
                                                              make_layout(get<0>(do_.layout()),
                                                                          get<2>(do_.layout()))));
     Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
-    Tensor do_fp32 = flash::convert_type<float>(do_reshaped);
-    Tensor o_fp32 = flash::convert_type<float>(o_reshaped);
+    Tensor do_fp32 = FLASH_NAMESPACE::convert_type<float>(do_reshaped);
+    Tensor o_fp32 = FLASH_NAMESPACE::convert_type<float>(o_reshaped);
     #pragma unroll
     for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
         float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
@@ -41,8 +42,8 @@ inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engi
         for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
             dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
         }
-        flash::SumOp<float> sum_op;
-        dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
+        FLASH_NAMESPACE::SumOp<float> sum_op;
+        dP_sum_cur = FLASH_NAMESPACE::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
         if (threadIdx.x % THREADS_PER_ROW == 0) {
             dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
         }
@@ -116,10 +117,10 @@ inline __device__ void compute_dot_do_o(const Params &params) {
 
     Tensor tdOrdO = make_fragment_like(tdOgdO);
     Tensor tdOrO = make_fragment_like(tdOgO);
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
+    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
         gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
     );
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
+    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
         gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
     );
     // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
@@ -244,7 +245,7 @@ inline __device__ void convert_dQ(const Params &params, const int nsplits) {
     #pragma unroll
     for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
     // Convert acc_dq from fp32 to fp16
-    Tensor rdQ = flash::convert_type<Element>(acc_dq);
+    Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
     Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
     cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
     __syncthreads();
@@ -257,7 +258,7 @@ inline __device__ void convert_dQ(const Params &params, const int nsplits) {
     #pragma unroll
     for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }
@@ -349,8 +350,8 @@ inline __device__ void convert_dKV(const Params &params) {
         acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
     }
     // Convert acc_dk from fp32 to fp16
-    Tensor rdK = flash::convert_type<Element>(acc_dk);
-    Tensor rdV = flash::convert_type<Element>(acc_dv);
+    Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
+    Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);
     Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
     Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
     cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
@@ -367,10 +368,10 @@ inline __device__ void convert_dKV(const Params &params) {
     #pragma unroll
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
-    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
 }

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu

@@ -1,10 +1,14 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template<>
 void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
 }
+
+} // namespace FLASH_NAMESPACE

+ 69 - 68
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include "namespace_config.h"
 #include "philox_unpack.cuh" // For at::cuda::philox::unpack
 
 #include <cute/tensor.hpp>
@@ -20,7 +21,7 @@
 #include "dropout.h"
 #include "rotary.h"
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -66,7 +67,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
 
     auto seed_offset = at::cuda::philox::unpack(params.philox_args);
-    flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
+    FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
                            bidb, bidh, tidx, params.h);
 
     // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
@@ -115,7 +116,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
         }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
             gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
         );
         #pragma unroll
@@ -246,7 +247,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Prologue
 
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                        binfo.actual_seqlen_q - m_block * kBlockM);
     if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
 
@@ -255,7 +256,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // // if (cute::thread0()) { print(sQNoSwizzle); }
 
     if (Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
@@ -265,14 +266,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
                                        binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
     // __syncthreads();
 
     if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<1>();
+        FLASH_NAMESPACE::cp_async_wait<1>();
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
@@ -281,10 +282,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     clear(acc_o);
 
-    flash::Softmax<2 * size<1>(acc_o)> softmax;
+    FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
 
     const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
+    FLASH_NAMESPACE::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
 
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
@@ -301,37 +302,37 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
 
         // Advance gV
         if (masking_step > 0) {
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+            FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
 
-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        FLASH_NAMESPACE::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
         );
 
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -343,7 +344,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
         // Convert acc_s from fp32 to fp16/bf16
-        Tensor rP = flash::convert_type<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
@@ -361,9 +362,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
         // if (cute::thread0()) { print(tOrP); }
-        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }
 
         // This check is at the end of the loop since we always have at least 1 iteration
@@ -377,23 +378,23 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
+        FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        FLASH_NAMESPACE::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -405,7 +406,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
-        Tensor rP = flash::convert_type<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
@@ -423,8 +424,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
 
     // Epilogue
@@ -432,7 +433,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
 
     // Convert acc_o from fp32 to fp16/bf16
-    Tensor rO = flash::convert_type<Element>(acc_o);
+    Tensor rO = FLASH_NAMESPACE::convert_type<Element>(acc_o);
     Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
     auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
@@ -487,7 +488,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
     }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }
@@ -563,7 +564,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
         }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
             gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
         );
         #pragma unroll
@@ -730,18 +731,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         auto tKgK_data = tKgK.data();
         auto tVgV_data = tVgV.data();
         for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
-            flash::copy_w_min_idx<Is_even_K>(
+            FLASH_NAMESPACE::copy_w_min_idx<Is_even_K>(
                 tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
             );
             tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
             if (params.rotary_dim == 0) {
-                flash::copy_w_min_idx<Is_even_K>(
+                FLASH_NAMESPACE::copy_w_min_idx<Is_even_K>(
                     tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
                 );
             } else {
                 if (params.is_rotary_interleaved) {
                     // Don't clear OOB_K because we're writing to global memory
-                    flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
+                    FLASH_NAMESPACE::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
                         tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
                         binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
                     );
@@ -749,7 +750,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                     tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
                 } else {
                     // Don't clear OOB_K because we're writing to global memory
-                    flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
+                    FLASH_NAMESPACE::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
                         tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
                         binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
                     );
@@ -784,7 +785,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Read Q from gmem to smem, optionally apply rotary embedding.
     if (!Append_KV || params.rotary_dim == 0) {
         // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
         const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
@@ -807,12 +808,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
         Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
         if (params.is_rotary_interleaved) {
-            flash::copy_rotary_interleaved<Is_even_K>(
+            FLASH_NAMESPACE::copy_rotary_interleaved<Is_even_K>(
                 tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
                 0, params.d, params.rotary_dim
             );
         } else {
-            flash::copy_rotary_contiguous<Is_even_K>(
+            FLASH_NAMESPACE::copy_rotary_contiguous<Is_even_K>(
                 tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
                 0, params.d, params.rotary_dim
             );
@@ -821,21 +822,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                        binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
 
-    // flash::cp_async_wait<0>();
+    // FLASH_NAMESPACE::cp_async_wait<0>();
     // __syncthreads();
     // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
     // __syncthreads();
 
     clear(acc_o);
 
-    flash::Softmax<2 * size<1>(acc_o)> softmax;
+    FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
 
     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
+    FLASH_NAMESPACE::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
 
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
@@ -852,7 +853,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
 
         // Advance gV
@@ -866,22 +867,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
                 tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+            FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
                 gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
 
-        flash::gemm(
+        FLASH_NAMESPACE::gemm(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
 
@@ -889,7 +890,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
         );
 
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
         // __syncthreads();
@@ -905,7 +906,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -918,12 +919,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
 
         // Convert acc_s from fp32 to fp16/bf16
-        Tensor rP = flash::convert_type<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
 
-        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
 
         // This check is at the end of the loop since we always have at least 1 iteration
         if (n_masking_steps > 1 && n_block <= n_block_min) {
@@ -936,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         // Advance gV
         if (block_table == nullptr) {
@@ -948,18 +949,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
             tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
         }
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
-        flash::gemm(
+        FLASH_NAMESPACE::gemm(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            flash::apply_softcap(acc_s, params.softcap);
+            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
         }
 
-        flash::cp_async_wait<0>();
+        FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         if (n_block > n_block_min) {
             // Advance gK
@@ -972,7 +973,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                 tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
             }
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -983,12 +984,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
-        Tensor rP = flash::convert_type<Element>(acc_s);
+        Tensor rP = FLASH_NAMESPACE::convert_type<Element>(acc_s);
         // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
 
-        flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
 
     // Epilogue
@@ -1005,7 +1006,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     >;
     auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
     auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
-    Tensor rO = flash::convert_type<ElementO>(acc_o);
+    Tensor rO = FLASH_NAMESPACE::convert_type<ElementO>(acc_o);
     Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
 
@@ -1064,7 +1065,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
     }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }
@@ -1087,7 +1088,7 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
 
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
+    FLASH_NAMESPACE::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1101,7 +1102,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    FLASH_NAMESPACE::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1242,7 +1243,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     }
     // Load Oaccum in then scale and accumulate to O
     for (int split = 0; split < params.num_splits; ++split) {
-        flash::copy</*Is_even_MN=*/false, Is_even_K>(
+        FLASH_NAMESPACE::copy</*Is_even_MN=*/false, Is_even_K>(
             gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
         );
         #pragma unroll
@@ -1262,7 +1263,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     }
     // if (cute::thread0()) { print_tensor(tOrO); }
 
-    Tensor rO = flash::convert_type<Element>(tOrO);
+    Tensor rO = FLASH_NAMESPACE::convert_type<Element>(tOrO);
     // Write to gO
     #pragma unroll
     for (int m = 0; m < size<1>(rO); ++m) {
@@ -1290,4 +1291,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     }
 }
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE

+ 7 - 3
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -3,6 +3,7 @@
  ******************************************************************************/
 
 #pragma once
+#include "namespace_config.h"
 #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
 
 #include "static_switch.h"
@@ -10,6 +11,8 @@
 #include "flash.h"
 #include "flash_fwd_kernel.h"
 
+namespace FLASH_NAMESPACE {
+
 // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 #define ARCH_SUPPORTS_FLASH
@@ -29,7 +32,7 @@ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
 DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
     #if defined(ARCH_SUPPORTS_FLASH)
         static_assert(!(Is_causal && Is_local)); // Enforce constraints
-        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
+        FLASH_NAMESPACE::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -37,7 +40,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, b
 
 DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
+        FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -45,7 +48,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo
 
 DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
     static_assert(Log_max_splits >= 1);
-    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
+    FLASH_NAMESPACE::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -327,3 +330,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
     });
 }
+}  // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 5 - 1
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu

@@ -1,7 +1,11 @@
 // Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
-
+#include "namespace_config.h"
 #include "flash_fwd_launch_template.h"
 
+namespace FLASH_NAMESPACE {
+
 template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
+
+} // namespace FLASH_NAMESPACE

+ 33 - 31
csrc/flash_attn/src/generate_kernels.py

@@ -1,8 +1,3 @@
-# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602
-
-# This file is run to generate the kernel instantiations for the flash_attn kernels
-# They are written to several files in order to speed up compilation
-
 import argparse
 import itertools
 from dataclasses import dataclass
@@ -17,27 +12,40 @@ DTYPE_MAP = {
 SM = [80]  # Sm80 kernels support up to
 HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256]
 IS_CAUSAL = ["false", "true"]
-KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
+NAMESPACE_INCLUDE = '#include "namespace_config.h"\n'
+
+def get_fwd_template() -> str:
+    return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h"
+
+namespace FLASH_NAMESPACE {{
 
 template<>
 void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
     run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
-"""
 
-KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
+}} // namespace FLASH_NAMESPACE"""
+
+def get_fwd_split_template() -> str:
+    return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h"
+
+namespace FLASH_NAMESPACE {{
 
 template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
-"""
 
-KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
+}} // namespace FLASH_NAMESPACE"""
+
+def get_bwd_template() -> str:
+    return NAMESPACE_INCLUDE + """#include "flash_bwd_launch_template.h"
+
+namespace FLASH_NAMESPACE {{
 
 template<>
 void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params &params, cudaStream_t stream) {{
     run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
-"""
 
+}} // namespace FLASH_NAMESPACE"""
 
 @dataclass
 class Kernel:
@@ -49,37 +57,33 @@ class Kernel:
 
     @property
     def template(self) -> str:
-        if self.direction == "fwd":
-            return KERNEL_IMPL_TEMPLATE_FWD.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
-            )
-        elif self.direction == "bwd":
-            return KERNEL_IMPL_TEMPLATE_BWD.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
-            )
-        else:
-            return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
-            )
+        template_funcs = {
+            "fwd": get_fwd_template,
+            "bwd": get_bwd_template,
+            "fwd_split": get_fwd_split_template
+        }
+        template_func = template_funcs[self.direction]
+        return template_func().format(
+            DTYPE=DTYPE_MAP[self.dtype],
+            HEAD_DIM=self.head_dim,
+            IS_CAUSAL=self.is_causal
+        )
 
     @property
     def filename(self) -> str:
         return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"
 
-
 def get_all_kernels() -> List[Kernel]:
     for direction in ["fwd", "fwd_split", "bwd"]:
         for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
 
-
 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
     prelude = """// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"\n
-"""
-    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)
-
+// This file is auto-generated. See "generate_kernels.py"\n"""
+    content = prelude + kernel.template
+    (autogen_dir / kernel.filename).write_text(content)
 
 def main(output_dir: Optional[str]) -> None:
     if output_dir is None:
@@ -90,13 +94,11 @@ def main(output_dir: Optional[str]) -> None:
     for kernel in get_all_kernels():
         write_kernel(kernel, output_dir)
 
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         prog="generate_kernels",
         description="Generate the flash_attention kernels template instantiations",
     )
-    # Set an optional output directory
     parser.add_argument(
         "-o",
         "--output_dir",

+ 4 - 3
csrc/flash_attn/src/mask.h

@@ -3,10 +3,11 @@
  ******************************************************************************/
 
 #pragma once
+#include "namespace_config.h"
 
 #include <cute/tensor.hpp>
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -137,7 +138,7 @@ struct Mask {
         // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
         if constexpr (Need_masking) {
             // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
+            Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));
             // Do we need both row and column indices, or just column incides?
             static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
             const int lane_id = threadIdx.x % 32;
@@ -210,4 +211,4 @@ struct Mask {
 
 };
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE

+ 67 - 0
csrc/flash_attn/src/namespace_config.h

@@ -0,0 +1,67 @@
+/**
+ * @file flash_namespace_config.h
+ * @brief Configuration file for Flash namespace management and isolation
+ *
+ * This header provides configuration macros for managing the Flash namespace
+ * across a codebase. It allows for flexible namespace naming and provides
+ * utilities for namespace declaration and scoping.
+ *
+ * Usage Examples:
+ *
+ * 1. Basic namespace wrapping:
+ * @code
+ *   BEGIN_FLASH_NAMESPACE
+ *   class FlashDevice {
+ *     // Implementation
+ *   };
+ *   END_FLASH_NAMESPACE
+ * @endcode
+ *
+ * 2. Accessing types within the namespace:
+ * @code
+ *   FLASH_NAMESPACE_ALIAS(FlashDevice) device;
+ * @endcode
+ *
+ * 3. Defining content within namespace scope:
+ * @code
+ *   FLASH_NAMESPACE_SCOPE(
+ *     struct Configuration {
+ *       uint32_t size;
+ *       bool enabled;
+ *     };
+ *   )
+ * @endcode
+ *
+ * 4. Custom namespace name:
+ * @code
+ *   #define FLASH_NAMESPACE custom_flash
+ *   #include "flash_namespace_config.h"
+ * @endcode
+ *
+ * Configuration:
+ * - The default namespace is 'flash' if FLASH_NAMESPACE is not defined
+ * - Define FLASH_NAMESPACE before including this header to customize the
+ * namespace name
+ *
+ * Best Practices:
+ * - Include this header in all files that need access to the Flash namespace
+ *
+ */
+#pragma once
+
+#ifndef FLASH_NAMESPACE_CONFIG_H
+#define FLASH_NAMESPACE_CONFIG_H
+
+// Set default namespace to flash
+#ifndef FLASH_NAMESPACE
+#define FLASH_NAMESPACE flash
+#endif
+
+#define FLASH_NAMESPACE_ALIAS(name) FLASH_NAMESPACE::name
+
+#define FLASH_NAMESPACE_SCOPE(content)                                         \
+  namespace FLASH_NAMESPACE {                                                  \
+  content                                                                      \
+  }
+
+#endif // FLASH_NAMESPACE_CONFIG_H

+ 4 - 2
csrc/flash_attn/src/philox.cuh

@@ -2,7 +2,9 @@
 #pragma once
 // Philox CUDA.
 
-namespace flash {
+#include "namespace_config.h"
+
+namespace FLASH_NAMESPACE {
 
 struct ull2 {
     unsigned long long x;
@@ -48,4 +50,4 @@ __forceinline__ __device__ uint4 philox(unsigned long long seed,
     return output;
 }
 
-} // namespace flash
+} // namespace FLASH_NAMESPACE

+ 3 - 2
csrc/flash_attn/src/rotary.h

@@ -6,11 +6,12 @@
 
 #include <cute/tensor.hpp>
 
+#include "namespace_config.h"
 #include "utils.h"
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-namespace flash {
+namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -149,4 +150,4 @@ __forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0>
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-}  // namespace flash
+}  // namespace FLASH_NAMESPACE

Daži faili netika attēloti, jo izmaiņu fails ir pārāk liels