
Don't specialize for hdim 224 to speed up compilation

Tri Dao, 8 months ago
revision 751c762c9c
+ 11 - 9
csrc/flash_attn/flash_api.cpp

@@ -443,7 +443,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size = round_multiple(head_size_og, 8);
-    const int head_size_rounded = round_multiple(head_size, 32);
+    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
@@ -679,7 +679,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size = round_multiple(head_size_og, 8);
-    const int head_size_rounded = round_multiple(head_size, 32);
+    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
 
@@ -874,17 +874,18 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
+    if (head_size > 192 && is_dropout) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = round_multiple(head_size, 32);
+    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
     TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 
     if (window_size_left >= seqlen_k) { window_size_left = -1; }
     if (window_size_right >= seqlen_k) { window_size_right = -1; }
@@ -1114,13 +1115,14 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
+    if (head_size > 192 && is_dropout) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = round_multiple(head_size, 32);
+    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
 
@@ -1415,7 +1417,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size = round_multiple(head_size_og, 8);
-    const int head_size_rounded = round_multiple(head_size, 32);
+    const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 

+ 0 - 8
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -299,14 +299,6 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename T>
-void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 224;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
-    });
-}
-
 template<typename T>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;

+ 0 - 27
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -299,33 +299,6 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename T, bool Is_causal>
-void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 224;
-    int device;
-    cudaGetDevice(&device);
-    int max_smem_per_block;
-    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
-    if (status_ != cudaSuccess) {
-      C10_CUDA_CHECK(status_);
-    }
-    // printf("max_smem_per_block = %d\n", max_smem_per_block);
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
-        // If we have N = 32, there are only 1024 elements to load at once, where each load
-        // is 8 elements. This means we can only use 128 threads and not 256 threads.
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-    });
-}
-
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
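The deleted run_mha_fwd_hdim224 launcher chose its tile shape from the opt-in shared-memory limit; its "112 KB" comment follows from 2 bytes per fp16/bf16 element times Headdim columns for a 128-row Q tile plus 64-row K and V tiles. A quick check of that arithmetic (my reading of the removed expression, not code from the repo):

    headdim, block_m, block_n, bytes_per_elem = 224, 128, 64, 2  # fp16/bf16
    smem_bytes = bytes_per_elem * headdim * (block_m + 2 * block_n)  # 2 * Headdim * (128 + 2 * 64)
    assert smem_bytes == 114688        # exactly the threshold in the removed check
    assert smem_bytes // 1024 == 112   # the "112 KB" noted in its comment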

+ 1 - 1
csrc/flash_attn/src/generate_kernels.py

@@ -15,7 +15,7 @@ DTYPE_MAP = {
 }
 
 SM = [80]  # Sm80 kernels support up to
-HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
+HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256]
 IS_CAUSAL = ["false", "true"]
 KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
 

+ 0 - 3
csrc/flash_attn/src/static_switch.h

@@ -107,9 +107,6 @@
     } else if (HEADDIM <= 192) {           \
       constexpr static int kHeadDim = 192; \
       return __VA_ARGS__();                \
-    } else if (HEADDIM <= 224) {           \
-      constexpr static int kHeadDim = 224; \
-      return __VA_ARGS__();                \
     } else if (HEADDIM <= 256) {           \
       constexpr static int kHeadDim = 256; \
       return __VA_ARGS__();                \
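With the 224 branch removed from HEADDIM_SWITCH, any runtime head dim in (192, 256] now selects the 256 instantiation, which is what makes padding head_size_rounded to 256 above safe. A plain-Python stand-in for the macro's dispatch, using the bucket list from generate_kernels.py (32, 64, 96, 128, 160, 192, 256):

    def kernel_head_dim(headdim: int) -> int:
        # Mirrors HEADDIM_SWITCH after this commit: no dedicated 224 bucket.
        for bucket in (32, 64, 96, 128, 160, 192, 256):
            if headdim <= bucket:
                return bucket
        raise ValueError("FlashAttention supports head dims up to 256")

    assert kernel_head_dim(224) == 256  # previously 224
    assert kernel_head_dim(192) == 192  # unchanged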

+ 0 - 10
setup.py

@@ -192,8 +192,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                 "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
@@ -208,8 +206,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                 "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
@@ -224,8 +220,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                 "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
@@ -240,8 +234,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                 "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
-                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
-                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
@@ -256,8 +248,6 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
                 "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu",
-                "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
                 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
             ],

+ 89 - 95
tests/test_flash_attn.py

@@ -682,7 +682,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
     # do_o = (g.float() * out.float()).sum(-1)
     # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
     # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
         (dqkv,) = torch.autograd.grad(out, qkv, g)
         (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
         (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
@@ -705,7 +705,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
@@ -829,7 +829,7 @@ def test_flash_attn_varlen_qkvpacked(
         print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
         (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)
         (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
@@ -853,7 +853,7 @@ def test_flash_attn_varlen_qkvpacked(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
@@ -866,9 +866,9 @@ def test_flash_attn_varlen_qkvpacked(
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
+# @pytest.mark.parametrize("alibi", [False])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@@ -894,7 +894,7 @@ def test_flash_attn_varlen_qkvpacked(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-# @pytest.mark.parametrize("dropout_p", [0.17])
+# @pytest.mark.parametrize("dropout_p", [0.0])
 @pytest.mark.parametrize("softcap", [0.0, 50.0])
 def test_flash_attn_output(
     seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
@@ -1066,7 +1066,7 @@ def test_flash_attn_output(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
         if kvpacked:
             (
                 dq,
@@ -1122,10 +1122,10 @@ def test_flash_attn_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
+        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
+        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
+        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
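The gradient checks in test_flash_attn_output now run whenever dropout is off or the hardware supports large head dims, and they allow FlashAttention's error to be up to 3x (previously 2x) that of the pure-PyTorch implementation, presumably to cover the head dims newly padded to 256. A condensed sketch of the comparison (helper and argument names are mine; inputs are torch tensors):

    def grads_close_enough(grads, grads_ref, grads_pt, mult=3):
        # FlashAttention's error vs. the reference must stay within `mult` times
        # the error of the pure-PyTorch implementation vs. the same reference.
        return all(
            (g - g_ref).abs().max().item() <= mult * (g_pt - g_ref).abs().max().item()
            for g, g_ref, g_pt in zip(grads, grads_ref, grads_pt)
        )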
 
 
 @pytest.mark.parametrize("kvpacked", [True, False])
@@ -1382,7 +1382,7 @@ def test_flash_attn_varlen_output(
         print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
+    if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)):
         if kvpacked:
             (
                 dq_unpad,
@@ -1441,7 +1441,7 @@ def test_flash_attn_varlen_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
 
-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1519,43 +1519,41 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        (
-            dq,
-            dk,
-            dv,
-        ) = torch.autograd.grad(out, (q, k, v), g)
-        (
-            dq_ref,
-            dk_ref,
-            dv_ref,
-        ) = torch.autograd.grad(out_ref, (q, k, v), g)
-        (
-            dq_pt,
-            dk_pt,
-            dv_pt,
-        ) = torch.autograd.grad(out_pt, (q, k, v), g)
-        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
-        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
-        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
-        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
-        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
-        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
-        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
-        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
-        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
-        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
-        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
-        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+    (
+        dq,
+        dk,
+        dv,
+    ) = torch.autograd.grad(out, (q, k, v), g)
+    (
+        dq_ref,
+        dk_ref,
+        dv_ref,
+    ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    (
+        dq_pt,
+        dk_pt,
+        dv_pt,
+    ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
 
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
+    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
+    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
+    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
 
 
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@@ -1684,7 +1682,7 @@ def test_flash_attn_varlen_causal(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
+    test_backward = block_table is None
     if test_backward:
         (
             dq_unpad,
@@ -1815,44 +1813,42 @@ def test_flash_attn_splitkv(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        (
-            dq,
-            dk,
-            dv,
-        ) = torch.autograd.grad(out, (q, k, v), g)
-        (
-            dq_ref,
-            dk_ref,
-            dv_ref,
-        ) = torch.autograd.grad(out_ref, (q, k, v), g)
-        (
-            dq_pt,
-            dk_pt,
-            dv_pt,
-        ) = torch.autograd.grad(out_pt, (q, k, v), g)
-        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
-        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
-        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
-        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
-        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
-        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
-        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
-        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
-        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
-        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
-        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
-        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+    (
+        dq,
+        dk,
+        dv,
+    ) = torch.autograd.grad(out, (q, k, v), g)
+    (
+        dq_ref,
+        dk_ref,
+        dv_ref,
+    ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    (
+        dq_pt,
+        dk_pt,
+        dv_pt,
+    ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
 
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
     mult = 2 if not alibi else 8
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
-        assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
-        assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
+    assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
+    assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
+    assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
 
 
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@@ -2208,7 +2204,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
     torch.random.manual_seed(42)
     out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
     g = torch.randn_like(out0)
-    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
         (
             dq0,
             dk0,
@@ -2223,7 +2219,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
 
-        if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
+        if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
             (
                 dq,
                 dk,
@@ -2430,13 +2426,12 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
     out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
 
     g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
-        for _ in range(50):
-            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
-            assert torch.equal(dv, dv0)
-            assert torch.equal(dk, dk0)
-            assert torch.equal(dq, dq0)
+    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+    for _ in range(50):
+        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+        assert torch.equal(dv, dv0)
+        assert torch.equal(dk, dk0)
+        assert torch.equal(dq, dq0)
 
 
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@@ -2518,10 +2513,9 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     )
 
     g = torch.randn_like(out)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-        for _ in range(50):
-            dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-            assert torch.equal(dv, dv0)
-            assert torch.equal(dk, dk0)
-            assert torch.equal(dq, dq0)
+    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+    for _ in range(50):
+        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        assert torch.equal(dv, dv0)
+        assert torch.equal(dk, dk0)
+        assert torch.equal(dq, dq0)