Jelajahi Sumber

Enable headdim 256 backward on consumer GPUs (Ampere, Ada)

Tri Dao 1 tahun lalu
induk
melakukan
2406f28805

+ 1 - 1
README.md

@@ -70,7 +70,7 @@ FlashAttention-2 currently supports:
    GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
    GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
-3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
+3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
 
 
 ## How to use FlashAttention

+ 4 - 4
csrc/flash_attn/flash_api.cpp

@@ -783,8 +783,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
+    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
@@ -1020,8 +1020,8 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
-    if (head_size > 192) {
-        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
+    if (head_size > 192 && (head_size <= 224 || is_dropout)) {
+        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
     }
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 

+ 1 - 1
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -521,7 +521,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
         flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
-        if (Is_dropout) {
+        if constexpr (Is_dropout) {
             int warp_id = tidx / 32;
             int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
             // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32

+ 5 - 1
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -296,8 +296,12 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) {  // H100
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
-        } else {  // A100, we don't do double buffering to save smem
+        } else if (max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem
             run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
+        } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
+            if constexpr (!Is_dropout) {
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
+            }
         }
     });
 }

+ 4 - 2
csrc/flash_attn/src/kernel_traits.h

@@ -231,9 +231,11 @@ struct Flash_bwd_kernel_traits : public Base {
     // TODO: generalize to other values of kBlockN
     // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
     // static constexpr int kPBlockN = kBlockN;
-    static_assert(kBlockN >= 64);
+    // Temporarily disabling this for hdim 256 on sm86 and sm89
+    // static_assert(kBlockN >= 64);
+    static_assert(kBlockN >= 32);
     // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
-    static constexpr int kPBlockN = 64;
+    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
     static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
     // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
     static constexpr int kSwizzlePdS = 3;

+ 38 - 28
tests/test_flash_attn.py

@@ -664,7 +664,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
     # do_o = (g.float() * out.float()).sum(-1)
     # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
     # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (dqkv,) = torch.autograd.grad(out, qkv, g)
         (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
         (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
@@ -687,7 +687,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
@@ -811,7 +811,7 @@ def test_flash_attn_varlen_qkvpacked(
         print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)
         (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
@@ -835,7 +835,7 @@ def test_flash_attn_varlen_qkvpacked(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
@@ -1036,7 +1036,7 @@ def test_flash_attn_output(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         if kvpacked:
             (
                 dq,
@@ -1092,7 +1092,7 @@ def test_flash_attn_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1339,7 +1339,7 @@ def test_flash_attn_varlen_output(
         print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         if kvpacked:
             (
                 dq_unpad,
@@ -1398,7 +1398,7 @@ def test_flash_attn_varlen_output(
         if not alibi:
             assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1476,7 +1476,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq,
             dk,
@@ -1509,7 +1509,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1625,7 +1625,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq_unpad,
             dk_unpad,
@@ -1661,7 +1661,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1755,7 +1755,7 @@ def test_flash_attn_splitkv(
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         (
             dq,
             dk,
@@ -1789,7 +1789,7 @@ def test_flash_attn_splitkv(
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
     mult = 2 if not alibi else 8
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
         assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
         assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
@@ -1815,8 +1815,9 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
-# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
-@pytest.mark.parametrize("paged_kv_block_size", [256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
+# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
+# @pytest.mark.parametrize("paged_kv_block_size", [256])
 @pytest.mark.parametrize("has_batch_idx", [False, True])
 # @pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
@@ -1900,12 +1901,13 @@ def test_flash_attn_kvcache(
             b=batch_size,
         )
         k_cache = rearrange(
-            k_cache_paged[block_table.flatten()],
+            # pytorch 1.12 doesn't have indexing with int32
+            k_cache_paged[block_table.to(dtype=torch.long).flatten()],
             "(b nblocks) block_size ... -> b (nblocks block_size) ...",
             b=batch_size,
         )[:, :seqlen_k]
         v_cache = rearrange(
-            v_cache_paged[block_table.flatten()],
+            v_cache_paged[block_table.to(dtype=torch.long).flatten()],
             "(b nblocks) block_size ... -> b (nblocks block_size) ...",
             b=batch_size,
         )[:, :seqlen_k]
@@ -1972,8 +1974,12 @@ def test_flash_attn_kvcache(
         cos, sin = None, None
         q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
-    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
-    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
+    k_cache_ref = (
+        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+    ).clone()
+    v_cache_ref = (
+        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+    ).clone()
     if new_kv:
         update_mask = torch.logical_and(
             cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
@@ -2044,16 +2050,20 @@ def test_flash_attn_kvcache(
     # of a Pytorch implementation.
     if new_kv:
         if paged_kv_block_size is None:
-            k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
-            v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+            k_cache_select = (
+                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+            )
+            v_cache_select = (
+                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+            )
         else:
             k_cache_select = rearrange(
-                k_cache_paged[block_table.flatten()],
+                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                 "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                 b=batch_size,
             )[:, :seqlen_k]
             v_cache_select = rearrange(
-                v_cache_paged[block_table.flatten()],
+                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                 "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                 b=batch_size,
             )[:, :seqlen_k]
@@ -2104,7 +2114,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
     torch.random.manual_seed(42)
     out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
     g = torch.randn_like(out0)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
         (
             dq0,
             dk0,
@@ -2119,7 +2129,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
 
-        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
             (
                 dq,
                 dk,
@@ -2326,7 +2336,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
     out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
 
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
@@ -2414,7 +2424,7 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     )
 
     g = torch.randn_like(out)
-    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)