Quellcode durchsuchen

Enable paged attention in varlen forward (#831)

* Enable paged attention in varlen forward

* Format + fix padding
Grigory Sizov vor 1 Jahr
Ursprung
Commit
2a15840f09
3 geänderte Dateien mit 106 neuen und 37 gelöschten Zeilen
  1. 37 7
      csrc/flash_attn/flash_api.cpp
  2. 9 1
      flash_attn/flash_attn_interface.py
  3. 60 29
      tests/test_flash_attn.py

+ 37 - 7
csrc/flash_attn/flash_api.cpp

@@ -494,12 +494,13 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
 
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
                const int max_seqlen_k,
@@ -535,6 +536,15 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);
 
+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -546,8 +556,12 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int total_k = k.size(0);
-    const int num_heads_k = k.size(1);
+    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int num_blocks = !paged_KV ? 0 : k.size(0);
+    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
 
     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }
@@ -575,8 +589,16 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
 
     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    if (!paged_KV) {
+        const int total_k = k.size(0);
+        CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    } else {
+        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+    }
+
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
@@ -654,6 +676,14 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                      window_size_left,
                      window_size_right,
                      seqlenq_ngroups_swapped);
+
+    if (paged_KV) {
+        params.block_table = block_table.data_ptr<int>();
+        params.block_table_batch_stride = block_table.stride(0);
+        params.k_batch_stride = k_padded.stride(0);
+        params.v_batch_stride = v_padded.stride(0);
+    }
+    params.page_block_size = page_block_size;
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
         set_params_splitkv(params, batch_size, num_heads,
@@ -682,7 +712,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
 
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
-        run_mha_fwd(params, stream);
+        run_mha_fwd(params, stream, paged_KV);
     } else {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
         out.zero_();

+ 9 - 1
flash_attn/flash_attn_interface.py

@@ -79,6 +79,7 @@ def _flash_attn_varlen_forward(
     window_size,
     alibi_slopes,
     return_softmax,
+    block_table,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -90,6 +91,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        block_table,
         alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
@@ -299,6 +301,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=None,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -440,6 +443,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=None,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -570,6 +574,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        block_table,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -587,6 +592,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=block_table,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -630,7 +636,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
@@ -1001,6 +1007,7 @@ def flash_attn_varlen_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    block_table=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1071,6 +1078,7 @@ def flash_attn_varlen_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        block_table,
     )
 
 

+ 60 - 29
tests/test_flash_attn.py

@@ -1542,8 +1542,12 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
         (1023, 1024),
     ],
 )
+# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
-def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
+def test_flash_attn_varlen_causal(
+    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
+):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1559,8 +1563,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
+    if paged_kv_block_size is None:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        block_table = None
+    else:
+        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
+        )
     query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
     (
@@ -1580,8 +1595,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
     out_unpad = flash_attn_varlen_func(
         q_unpad,
-        k_unpad,
-        v_unpad,
+        k_unpad if paged_kv_block_size is None else k_cache_paged,
+        v_unpad if paged_kv_block_size is None else v_cache_paged,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
@@ -1589,6 +1604,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
         0.0,
         causal=causal,
         window_size=window_size,
+        block_table=block_table,
     )
     out = output_pad_fn(out_unpad)
     out_ref, attn_ref = attention_ref(
@@ -1625,7 +1641,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
 
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
+    if test_backward:
         (
             dq_unpad,
             dk_unpad,
@@ -1661,7 +1678,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    if test_backward:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1888,29 +1905,16 @@ def test_flash_attn_kvcache(
         v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
         block_table = None
     else:
-        num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
-        k_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        v_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        block_table = rearrange(
-            torch.randperm(num_blocks, dtype=torch.int32, device=device),
-            "(b nblocks) -> b nblocks",
-            b=batch_size,
+        (
+            k_cache,
+            v_cache,
+            block_table,
+            k_cache_paged,
+            v_cache_paged,
+            num_blocks,
+        ) = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
         )
-        k_cache = rearrange(
-            # pytorch 1.12 doesn't have indexing with int32
-            k_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
-        v_cache = rearrange(
-            v_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
@@ -2073,6 +2077,33 @@ def test_flash_attn_kvcache(
     assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
 
 
+def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
+    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
+    k_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    v_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    block_table = rearrange(
+        torch.randperm(num_blocks, dtype=torch.int32, device=device),
+        "(b nblocks) -> b nblocks",
+        b=batch_size,
+    )
+    k_cache = rearrange(
+        # pytorch 1.12 doesn't have indexing with int32
+        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    v_cache = rearrange(
+        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
+
+
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("causal", [False, True])