Selaa lähdekoodia

Simplify the implementation of KVcache attn by appending KV first

Tri Dao 1 vuosi sitten
vanhempi
commit
56b7fc6ee0
2 muutettua tiedostoa jossa 119 lisäystä ja 105 poistoa
  1. 51 85
      csrc/flash_attn/src/flash_fwd_kernel.h
  2. 68 20
      tests/test_flash_attn.py

+ 51 - 85
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -657,10 +657,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
-    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
-        + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
-        + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -672,18 +668,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.v_row_stride, _1{}));
-    // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
-    // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
-    // This maps to accessing the first 64 rows of knew_ptr.
-    Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
-                                             + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
-                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                               make_stride(params.knew_row_stride, _1{}));
-    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
-    Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
-                                             + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
-                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                               make_stride(params.vnew_row_stride, _1{}));
 
     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQ{});
@@ -698,10 +682,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
-    Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
-    Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
 
     typename Kernel_traits::TiledMma tiled_mma;
@@ -762,6 +744,49 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Prologue
 
+    if constexpr (Append_KV) {
+        // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
+        // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
+        // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+            + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+            + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+        // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+        // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+        // This maps to accessing the first 64 rows of knew_ptr.
+        Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+                                                + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                  make_stride(params.knew_row_stride, _1{}));
+        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+        Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+                                                + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                  make_stride(params.vnew_row_stride, _1{}));
+        Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
+        Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
+
+        const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
+        for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
+            flash::copy_w_min_idx<Is_even_K>(
+                tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            );
+            flash::copy_w_min_idx<Is_even_K>(
+                tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            );
+            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+            tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+        }
+        __syncthreads();
+        if (n_block_max > n_block_copy_min) {
+            tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
+            tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
+        }
+    }
+
     Tensor tQrQ = make_fragment_like(tQgQ);
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
@@ -769,10 +794,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
-        gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
-        binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-    );
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+                                       binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();
 
     // flash::cp_async_wait<0>();
@@ -800,32 +823,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::cp_async_wait<0>();
         __syncthreads();
 
-        if constexpr (Append_KV) {
-            // if (cute::thread0()) { print(tKgK); }
-            // if (cute::thread0()) { print(tKsK); }
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-            // __syncthreads();
-            // if (cute::thread0()) { print(tKgK); }
-            // __syncthreads();
-        }
-
         // Advance gV
         if (masking_step > 0) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
-                binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
@@ -856,26 +861,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
         // __syncthreads();
 
-        // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
-
         if (n_block > n_block_min) {
             // Advance gK
-            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
-            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
-                binfo.seqlen_k_cache - (n_block - 1) * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -909,20 +898,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         clear(acc_s);
         flash::cp_async_wait<0>();
         __syncthreads();
-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
         // Advance gV
         tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
-        flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-            gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
-        );
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
         flash::gemm(
@@ -932,22 +910,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         flash::cp_async_wait<0>();
         __syncthreads();
-        if constexpr (Append_KV) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
-            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
-                flash::copy_w_min_idx<Is_even_K>(
-                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-                );
-            }
-        }
         if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
-            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
-                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
-                binfo.seqlen_k_cache - (n_block - 1) * kBlockN
-            );
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();

+ 68 - 20
tests/test_flash_attn.py

@@ -149,8 +149,9 @@ def generate_qkv(
         )
 
 
-def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None,
-                          device=None):
+def construct_causal_mask(
+    seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None
+):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
     sk = (
@@ -364,12 +365,18 @@ def convert_flash_attn_S_to_softmax(
         causal_mask = construct_causal_mask(
             seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device
         )
-        causal_mask = F.pad(causal_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True)
+        causal_mask = F.pad(
+            causal_mask,
+            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
+            value=True,
+        )
         S_converted.masked_fill_(causal_mask, 0.0)
 
     # Need to zero out things not in attention_mask in case S was initialized with random values
     # and some of those values aren't overwritten.
-    seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
+    seqlen_q_og = (
+        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
+    )
     if query_padding_mask is not None:
         query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
         S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
@@ -623,7 +630,14 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     out = output_pad_fn(out_unpad)
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+            S_dmask,
+            seqlen,
+            seqlen,
+            key_padding_mask,
+            key_padding_mask,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
         )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
@@ -996,7 +1010,14 @@ def test_flash_attn_varlen_output(
     out = output_pad_fn(out_unpad)
     if dropout_p > 0.0:
         S_dmask_converted = convert_flash_attn_S_to_softmax(
-            S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
+            S_dmask,
+            seqlen_q,
+            seqlen_k,
+            query_padding_mask,
+            key_padding_mask,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
         )
         dropout_mask = S_dmask_converted >= 0
         attn_unnorm = S_dmask_converted.abs()
@@ -1466,16 +1487,18 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
 
+
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("num_splits", [1, 0])
 # @pytest.mark.parametrize("num_splits", [0])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-# @pytest.mark.parametrize("mha_type", ["mqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("new_kv", [False, True])
-# @pytest.mark.parametrize("new_kv", [False])
+# @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [True])
+# @pytest.mark.parametrize("causal", [False])
+@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1499,7 +1522,9 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
+def test_flash_attn_kvcache(
+    seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
+):
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
     device = "cuda"
@@ -1510,14 +1535,21 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
     if new_kv:
-        k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
-        v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
+        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
     else:
         k, v = None, None
     k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
-    cache_seqlens = torch.randint(0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size, ), dtype=torch.int32, device=device)
+    cache_seqlens = torch.randint(
+        0,
+        (seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
+        (batch_size,),
+        dtype=torch.int32,
+        device=device,
+    )
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
     # k_cache[:, 64:] = -1
     k_cache_ref = k_cache.clone()
@@ -1525,12 +1557,16 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     if new_kv:
-        update_mask = torch.logical_and(cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q)
+        update_mask = torch.logical_and(
+            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
+        )
         k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
         v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
-    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
+    out = flash_attn_with_kvcache(
+        q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
+    )
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
     # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
@@ -1539,10 +1575,22 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
     # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
     # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
     # probs = torch.softmax(qk, dim=-1)
-    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
-    out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal)
-    out_pt, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal,
-                              upcast=False, reorder_ops=True)
+    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    out_ref, _ = attention_ref(
+        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
+    )
+    out_pt, _ = attention_ref(
+        q,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=causal,
+        upcast=False,
+        reorder_ops=True,
+    )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1583,7 +1631,7 @@ def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num
         (1024, 1024),
     ],
 )
-@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
+@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.0])
 def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
     device = "cuda"