Kaynağa Gözat

Preprocessor switches to control functionality (#788)

For faster and smaller builds in some simple cases,
provide switches to allow disabling
-backward
-alibi
-uneven k
-dropout
-local attention

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
Jeremy Reizenstein 1 yıl önce
ebeveyn
işleme
0658e320f6

+ 44 - 55
csrc/flash_attn/flash_api.cpp

@@ -112,6 +112,9 @@ void set_params_fprop(Flash_fwd_params &params,
     params.rp_dropout = 1.f / params.p_dropout;
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     TORCH_CHECK(p_dropout < 1.f);
+    #ifdef FLASHATTENTION_DISABLE_DROPOUT
+        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
+    #endif
 
     // Causal is the special case where window_size_right == 0 and window_size_left < 0.
     // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
@@ -122,7 +125,16 @@ void set_params_fprop(Flash_fwd_params &params,
     params.window_size_left = window_size_left;
     params.window_size_right = window_size_right;
 
+    #ifdef FLASHATTENTION_DISABLE_LOCAL
+        TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
+            "This flash attention build does not support local attention.");
+    #endif
+
     params.is_seqlens_k_cumulative = true;
+
+    #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+        TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
+    #endif
 }
 
 void set_params_dgrad(Flash_bwd_params &params,
@@ -282,6 +294,25 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
     }
 }
 
+void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
+#ifdef FLASHATTENTION_DISABLE_ALIBI
+    TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
+    params.alibi_slopes_ptr = nullptr;
+#else
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
+        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    } else {
+        params.alibi_slopes_ptr = nullptr;
+    }
+#endif
+}
+
 std::vector<at::Tensor>
 mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
@@ -435,17 +466,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     if (seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -657,17 +678,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -724,6 +735,9 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
 
+    #ifdef FLASHATTENTION_DISABLE_BACKWARD
+        TORCH_CHECK(false, "This flash attention build does not support backward.");
+    #endif
     if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -903,17 +917,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         params.rng_state[1] = std::get<1>(seeds);
     }
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     if (seqlen_q > 0) {
         launch(params, stream);
@@ -963,6 +967,10 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {
 
+    #ifdef FLASHATTENTION_DISABLE_BACKWARD
+        TORCH_CHECK(false, "This flash attention build does not support backward.");
+    #endif
+
     if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -1158,17 +1166,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         params.rng_state[1] = std::get<1>(seeds);
     }
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     if (max_seqlen_q > 0) {
         launch(params, stream);
@@ -1435,17 +1433,8 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     }
     params.page_block_size = page_block_size;
 
-    if (alibi_slopes_.has_value()) {
-        auto alibi_slopes = alibi_slopes_.value();
-        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
-        CHECK_DEVICE(alibi_slopes);
-        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
-    } else {
-        params.alibi_slopes_ptr = nullptr;
-    }
+
+    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,

+ 13 - 11
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -69,9 +69,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
+                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
@@ -100,7 +100,9 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
 
 template<typename Kernel_traits, bool Is_dropout>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+#ifndef FLASHATTENTION_DISABLE_BACKWARD
     run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
+#endif
 }
 
 template<typename T>
@@ -114,7 +116,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
       C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
             if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
                 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
@@ -139,7 +141,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
       C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // Changing AtomLayoutMdQ from 2 to 4 takes the same time
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
@@ -184,7 +186,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
       C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr(!Is_dropout) {  // 92KB
                 run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
@@ -210,7 +212,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
       C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
         // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
         // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
@@ -243,7 +245,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
       C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
         } else {
@@ -263,7 +265,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
       C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 136 * 1024) {
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
         } else {
@@ -275,7 +277,7 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
     });
 }
@@ -291,7 +293,7 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     if (status_ != cudaSuccess) {
       C10_CUDA_CHECK(status_);
     }
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) {  // H100
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
         } else {  // A100, we don't do double buffering to save smem

+ 15 - 15
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -42,10 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     const bool return_softmax = params.p_ptr != nullptr;
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-            BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                 BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // Will only return softmax if dropout, to reduce compilation time.
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If return_softmax, set IsEvenMNConst to false to reduce number of templates
@@ -83,11 +83,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
-            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                                 // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                 // If Is_local, set Is_causal to false
@@ -113,7 +113,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         // If headdim is divisible by 64, then we set kBlockM = 8, etc.
         constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
         dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
-        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             if (params.num_splits <= 2) {
                 flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 4) {
@@ -147,7 +147,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
         });
@@ -157,7 +157,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
@@ -181,7 +181,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
             if (is_sm8x) {
@@ -207,7 +207,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -244,7 +244,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, H100, 128 x 32 is the fastest.
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
@@ -272,7 +272,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -300,7 +300,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
       C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -331,7 +331,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
       C10_CUDA_CHECK(status_);
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
-    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, we want to run with 128 x 64 (128KB smem).
             // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.

+ 41 - 0
csrc/flash_attn/src/static_switch.h

@@ -14,6 +14,7 @@
 ///     some_function<BoolConst>(...);
 /// });
 /// ```
+
 #define BOOL_SWITCH(COND, CONST_NAME, ...)      \
   [&] {                                         \
     if (COND) {                                 \
@@ -25,6 +26,46 @@
     }                                           \
   }()
 
+#ifdef FLASHATTENTION_DISABLE_DROPOUT
+  #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;   \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define DROPOUT_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_ALIBI
+  #define ALIBI_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;   \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define ALIBI_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+  #define EVENK_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = true;    \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define EVENK_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_LOCAL
+  #define LOCAL_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;    \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define LOCAL_SWITCH BOOL_SWITCH
+#endif
+
 #define FP16_SWITCH(COND, ...)               \
   [&] {                                      \
     if (COND) {                              \