Преглед изворни кода

Implement page KV cache

Co-authored-by: ljss <450993438@qq.com>
Tri Dao пре 1 година
родитељ
комит
54e80a3829

+ 37 - 2
README.md

@@ -148,6 +148,7 @@ def flash_attn_with_kvcache(
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
@@ -160,6 +161,10 @@ def flash_attn_with_kvcache(
     the previous step, and update them with the new keys/values from the current step, and do
     attention with the updated cache, all in 1 kernel.
 
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
     Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
     rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
     If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
@@ -169,12 +174,36 @@ def flash_attn_with_kvcache(
 
     See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
 
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
     Note: Does not support backward pass.
 
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
-        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
-        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
+            page_block_size must be a multiple of 256.
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
@@ -183,6 +212,7 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
+        block_table [optional]: (num_blocks, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
@@ -279,6 +309,11 @@ Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for
 
 Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
 
+### 2.5: Paged KV cache.
+
+Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)).
+Thanks to @beginlner for this contribution.
+
 ## Performance
 
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).

+ 42 - 12
csrc/flash_attn/flash_api.cpp

@@ -257,8 +257,9 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
 }
 
 void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
-    const int num_heads, const int head_size,  const int max_seqlen_k, const int max_seqlen_q,
-    const int head_size_rounded,  float p_dropout, const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
+    const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
+    const int head_size_rounded, const float p_dropout,
+    const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
 
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
@@ -635,8 +636,8 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
         set_params_splitkv(params, batch_size, num_heads,
-                       head_size, max_seqlen_k, max_seqlen_q,
-                       head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
+                           head_size, max_seqlen_k, max_seqlen_q,
+                           head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
     }
 
     // number of times random will be generated per thread, to offset philox counter in thc random
@@ -1194,14 +1195,15 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
 
 std::vector<at::Tensor>
 mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
-                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
-                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size
+                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                 c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_, // batch_size
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                 c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
@@ -1235,15 +1237,30 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
 
+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
     const auto sizes = q.sizes();
 
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
     int num_heads = sizes[2];
     const int head_size_og = sizes[3];
-    const int seqlen_k = kcache.size(1);
+
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int num_blocks = !paged_KV ? 0 : kcache.size(0);
+    const int page_block_size = !paged_KV ? 1 : kcache.size(1);
+    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
+    const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
     const int num_heads_k = kcache.size(2);
-    const int batch_size_c = kcache.size(0);
+    const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
@@ -1266,8 +1283,14 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     if (window_size_right >= seqlen_k) { window_size_right = -1; }
 
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
-    CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+    if (!paged_KV) {
+        CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+        CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+    } else {
+        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+    }
 
     at::Tensor q_padded, kcache_padded, vcache_padded;
     if (head_size_og % 8 != 0) {
@@ -1406,6 +1429,12 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                        head_size, seqlen_k, seqlen_q,
                        head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
 
+    if (paged_KV) {
+        params.block_table = block_table.data_ptr<int>();
+        params.block_table_batch_stride = block_table.stride(0);
+    }
+    params.page_block_size = page_block_size;
+
     if (alibi_slopes_.has_value()) {
         auto alibi_slopes = alibi_slopes_.value();
         TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
@@ -1419,8 +1448,9 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     }
 
     auto stream = at::cuda::getCurrentCUDAStream().stream();
-    // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
-    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());
+    // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
+    // or paged KV cache
+    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
 
     if (head_size_og % 8 != 0) {
         out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});

+ 5 - 0
csrc/flash_attn/src/flash.h

@@ -101,6 +101,11 @@ struct Flash_fwd_params : public Qkv_params {
     // The indices to index into the KV cache.
     int * __restrict__ cache_batch_idx;
 
+    // Paged KV cache
+    int * __restrict__ block_table;
+    index_t block_table_batch_stride;
+    int page_block_size;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;

+ 66 - 14
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -559,10 +559,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
     // We move K and V to the last block.
     const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
-    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
-        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
-    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
-        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
+    const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
+    const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
+    const index_t row_offset_k = block_table == nullptr
+        ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+          + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
+        : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
+    const index_t row_offset_v = block_table == nullptr
+        ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+          + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
+        : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -695,11 +702,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
 
         const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
+        auto tKgK_data = tKgK.data();
+        auto tVgV_data = tVgV.data();
         for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
             flash::copy_w_min_idx<Is_even_K>(
                 tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
             );
-            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
             tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
             if (params.rotary_dim == 0) {
                 flash::copy_w_min_idx<Is_even_K>(
@@ -725,15 +733,27 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
                 }
             }
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+            if (block_table == nullptr) {
+                tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            } else {
+                if (n_block > n_block_copy_min) {
+                    const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
+                    const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
+                    const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
+                    const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
+                    const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
+                    const int offset_diff = block_table_offset_next - block_table_offset_cur;
+                    tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
+                    tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
+                }
+            }
         }
         // Need this before we can read in K again, so that we'll see the updated K values.
         __syncthreads();
-        if (n_block_max > n_block_copy_min) {
-            tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
-            tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
-        }
+        tKgK.data() = tKgK_data;
+        tVgV.data() = tVgV_data;
     }
 
     // Read Q from gmem to smem, optionally apply rotary embedding.
@@ -812,7 +832,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         // Advance gV
         if (masking_step > 0) {
-            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+            if (block_table == nullptr) {
+                tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+            } else {
+                const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
+                const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
+                const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
+                const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
+                tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
+            }
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
@@ -839,7 +867,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         if (n_block > n_block_min) {
             // Advance gK
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            if (block_table == nullptr) {
+                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            } else {
+                const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
+                const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
+                const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
+                const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
+                tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
+            }
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
@@ -874,7 +910,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::cp_async_wait<0>();
         __syncthreads();
         // Advance gV
-        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+        if (block_table == nullptr) {
+            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+        } else {
+            const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
+            const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
+            const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
+            const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
+            tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
+        }
         flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
@@ -887,7 +931,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         __syncthreads();
         if (n_block > n_block_min) {
             // Advance gK
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            if (block_table == nullptr) {
+                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            } else {
+                const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
+                const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
+                const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
+                const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
+                tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
+            }
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.

+ 9 - 2
flash_attn/flash_attn_interface.py

@@ -1084,6 +1084,7 @@ def flash_attn_with_kvcache(
     rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
@@ -1135,8 +1136,11 @@ def flash_attn_with_kvcache(
 
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
-        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
-        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim)
+        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
+            page_block_size must be a multiple of 256.
+        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
+            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
@@ -1145,6 +1149,7 @@ def flash_attn_with_kvcache(
         rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
+        block_table [optional]: (num_blocks, max_num_blocks_per_seq), dtype torch.int32.
         cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
             If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
             If the indices are not distinct, and k and v are provided, the values updated in the cache
@@ -1180,6 +1185,7 @@ def flash_attn_with_kvcache(
         )
         cache_seqlens = maybe_contiguous(cache_seqlens)
     cache_batch_idx = maybe_contiguous(cache_batch_idx)
+    block_table = maybe_contiguous(block_table)
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
         q,
         k_cache,
@@ -1190,6 +1196,7 @@ def flash_attn_with_kvcache(
         rotary_cos,
         rotary_sin,
         cache_batch_idx,
+        block_table,
         alibi_slopes,
         None,
         softmax_scale,

+ 72 - 19
tests/test_flash_attn.py

@@ -708,7 +708,9 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
 # @pytest.mark.parametrize('seqlen', [128])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
+def test_flash_attn_varlen_qkvpacked(
+    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
+):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = "cuda"
@@ -1698,7 +1700,9 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype):
+def test_flash_attn_splitkv(
+    seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
+):
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
@@ -1800,7 +1804,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [False])
 @pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
+# @pytest.mark.parametrize("alibi", [False])
 @pytest.mark.parametrize("local", [False, True])
 # @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
@@ -1811,10 +1815,12 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
+# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [256, 512])
 @pytest.mark.parametrize("has_batch_idx", [False, True])
 # @pytest.mark.parametrize("has_batch_idx", [False])
-@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [128])
@@ -1840,6 +1846,7 @@ def test_flash_attn_kvcache(
     seqlen_k,
     d,
     has_batch_idx,
+    paged_kv_block_size,
     rotary_fraction,
     rotary_interleaved,
     seqlen_new_eq_seqlen_q,
@@ -1855,6 +1862,8 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if not new_kv and rotary_fraction > 0.0:
         pytest.skip()
+    if has_batch_idx and paged_kv_block_size is not None:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -1873,10 +1882,35 @@ def test_flash_attn_kvcache(
         v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
-    k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    if paged_kv_block_size is None:
+        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+        block_table = None
+    else:
+        num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
+        k_cache_paged = torch.randn(
+            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+        )
+        v_cache_paged = torch.randn(
+            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+        )
+        block_table = rearrange(
+            torch.randperm(num_blocks, dtype=torch.int32, device=device),
+            "(b nblocks) -> b nblocks",
+            b=batch_size,
+        )
+        k_cache = rearrange(
+            k_cache_paged[block_table.flatten()],
+            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+            b=batch_size,
+        )[:, :seqlen_k]
+        v_cache = rearrange(
+            v_cache_paged[block_table.flatten()],
+            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+            b=batch_size,
+        )[:, :seqlen_k]
     cache_seqlens = torch.randint(
-        0,
+        0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
         (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
         if new_kv
@@ -1903,7 +1937,15 @@ def test_flash_attn_kvcache(
         alibi_slopes, attn_bias = None, None
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     if rotary_dim > 0:
-        angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
+        angle = (
+            torch.rand(
+                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
+                rotary_dim // 2,
+                device=device,
+            )
+            * 2
+            * math.pi
+        )
         cos = torch.cos(angle).to(dtype=dtype)
         sin = torch.sin(angle).to(dtype=dtype)
         if causal or local:
@@ -1942,14 +1984,15 @@ def test_flash_attn_kvcache(
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
     out = flash_attn_with_kvcache(
         q,
-        k_cache,
-        v_cache,
+        k_cache if paged_kv_block_size is None else k_cache_paged,
+        v_cache if paged_kv_block_size is None else v_cache_paged,
         k,
         v,
-        cos,
-        sin,
-        cache_seqlens,
-        cache_batch_idx,
+        rotary_cos=cos,
+        rotary_sin=sin,
+        cache_seqlens=cache_seqlens,
+        cache_batch_idx=cache_batch_idx,
+        block_table=block_table,
         causal=causal,
         window_size=window_size,
         rotary_interleaved=rotary_interleaved,
@@ -2000,8 +2043,20 @@ def test_flash_attn_kvcache(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     if new_kv:
-        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
-        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        if paged_kv_block_size is None:
+            k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
+            v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        else:
+            k_cache_select = rearrange(
+                k_cache_paged[block_table.flatten()],
+                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+                b=batch_size,
+            )[:, :seqlen_k]
+            v_cache_select = rearrange(
+                v_cache_paged[block_table.flatten()],
+                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+                b=batch_size,
+            )[:, :seqlen_k]
         assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
         assert torch.equal(v_cache_select, v_cache_ref)
     mult = 3 if not alibi else 5
@@ -2280,8 +2335,6 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
             assert torch.equal(dq, dq0)
 
 
-
-
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("local", [False, True])