Browse Source

[Gen] Accept cache_batch_idx to index into the KV cache

Tri Dao 1 năm trước cách đây
mục cha
commit
e279bf8ed9

+ 15 - 6
csrc/flash_attn/flash_api.cpp

@@ -1037,13 +1037,14 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
 
 std::vector<at::Tensor>
 mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
-                const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
-                const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
+                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
+                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
                 c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_, // batch_size
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
@@ -1084,6 +1085,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     const int head_size_og = sizes[3];
     const int seqlen_k = kcache.size(1);
     const int num_heads_k = kcache.size(2);
+    const int batch_size_c = kcache.size(0);
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -1102,8 +1104,8 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     }
 
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
-    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
 
     at::Tensor q_padded, kcache_padded, vcache_padded;
     if (head_size_og % 8 != 0) {
@@ -1229,6 +1231,13 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
         params.rotary_dim = 0;
     }
 
+    if (cache_batch_idx_.has_value()) {
+        auto cache_batch_idx = cache_batch_idx_.value();
+        CHECK_DEVICE(cache_batch_idx);
+        CHECK_CONTIGUOUS(cache_batch_idx);
+        TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
+        params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
+    }
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
@@ -1248,8 +1257,8 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     }
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    // Only split kernel supports appending to KV cache
-    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());
+    // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
+    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());
 
     if (head_size_og % 8 != 0) {
         out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});

+ 3 - 0
csrc/flash_attn/src/flash.h

@@ -95,6 +95,9 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ rotary_cos_ptr;
     void * __restrict__ rotary_sin_ptr;
 
+    // The indices to index into the KV cache.
+    int *__restrict__ cache_batch_idx;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;

+ 3 - 2
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -668,9 +668,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
     // We move K and V to the last block.
-    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+    const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
+    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
-    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),

+ 10 - 2
flash_attn/flash_attn_interface.py

@@ -928,6 +928,7 @@ def flash_attn_with_kvcache(
     rotary_cos=None,
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
@@ -978,8 +979,8 @@ def flash_attn_with_kvcache(
 
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
-        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
-        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
@@ -988,6 +989,10 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
+        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
+            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
+            If the indices are not distinct, and k and v are provided, the values updated in the cache
+                 might come from any of the duplicate indices.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -1014,6 +1019,8 @@ def flash_attn_with_kvcache(
         cache_seqlens = torch.full(
             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
         )
+        cache_seqlens = maybe_contiguous(cache_seqlens)
+    cache_batch_idx = maybe_contiguous(cache_batch_idx)
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
         q,
         k_cache,
@@ -1023,6 +1030,7 @@ def flash_attn_with_kvcache(
         cache_seqlens,
         rotary_cos,
         rotary_sin,
+        cache_batch_idx,
         None,
         softmax_scale,
         causal,

+ 18 - 7
tests/test_flash_attn.py

@@ -1668,7 +1668,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@@ -1677,6 +1677,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
+@pytest.mark.parametrize("has_batch_idx", [False, True])
+# @pytest.mark.parametrize("has_batch_idx", [True])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1703,6 +1705,7 @@ def test_flash_attn_kvcache(
     seqlen_q,
     seqlen_k,
     d,
+    has_batch_idx,
     rotary_fraction,
     rotary_interleaved,
     seqlen_new_eq_seqlen_q,
@@ -1721,6 +1724,7 @@ def test_flash_attn_kvcache(
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
+    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
     nheads = 6
     # rotary_dim must be a multiple of 16, and must be <= d
     rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
@@ -1734,8 +1738,8 @@ def test_flash_attn_kvcache(
         v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
-    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     cache_seqlens = torch.randint(
         0,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
@@ -1746,6 +1750,10 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_batch_idx:
+        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
+    else:
+        cache_batch_idx = None
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     if rotary_dim > 0:
         angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
@@ -1775,8 +1783,8 @@ def test_flash_attn_kvcache(
         cos, sin = None, None
         q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
-    k_cache_ref = k_cache.clone()
-    v_cache_ref = v_cache.clone()
+    k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
+    v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:
@@ -1796,6 +1804,7 @@ def test_flash_attn_kvcache(
         cos,
         sin,
         cache_seqlens,
+        cache_batch_idx,
         causal=causal,
         window_size=window_size,
         rotary_interleaved=rotary_interleaved,
@@ -1844,8 +1853,10 @@ def test_flash_attn_kvcache(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     if new_kv:
-        assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3)
-        assert torch.equal(v_cache, v_cache_ref)
+        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
+        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
+        assert torch.equal(v_cache_select, v_cache_ref)
     assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5