Browse Source

Implement flash_attn_with_kvcache

Tri Dao 1 year ago
parent
commit
37c6e05406

+ 192 - 3
csrc/flash_attn/flash_api.cpp

@@ -102,6 +102,7 @@ void set_params_fprop(Flash_fwd_params &params,
     TORCH_CHECK(p_dropout < 1.f);
 
     params.is_causal = is_causal;
+    params.is_seqlens_k_cumulative = true;
 }
 
 void set_params_dgrad(Flash_bwd_params &params,
@@ -175,10 +176,10 @@ void set_params_dgrad(Flash_bwd_params &params,
     params.dsoftmax_sum = dsoftmax_sum_d;
 }
 
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         FWD_HEADDIM_SWITCH(params.d, [&] {
-            if (params.num_splits <= 1) {  // If we don't set it num_splits == 0
+            if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
             } else {
                 run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
@@ -350,7 +351,7 @@ mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head
     const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
     params.num_splits = 1;
     if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
-        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 64);
+        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
         if (params.num_splits > 1) {
             at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
             at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
@@ -990,10 +991,198 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     return { dq, dk, dv, softmax_d };
 }
 
+std::vector<at::Tensor>
+mha_fwd_kvcache(const at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
+                const at::Tensor &kcache,            // batch_size x seqlen_k x num_heads_k x head_size
+                const at::Tensor &vcache,            // batch_size x seqlen_k x num_heads_k x head_size
+                c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
+                c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
+                c10::optional<const at::Tensor> &seqlens_k_, // batch_size
+                c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
+                const float softmax_scale,
+                const bool is_causal,
+                int num_splits
+                ) {
+
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
+    // We will support Turing in the near future
+    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
+
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+    if (q_dtype == torch::kBFloat16) {
+        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
+    }
+    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
+
+    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
+    TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
+    TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+
+    const auto sizes = q.sizes();
+
+    const int batch_size = sizes[0];
+    const int seqlen_q = sizes[1];
+    const int num_heads = sizes[2];
+    const int head_size_og = sizes[3];
+    const int seqlen_k = kcache.size(1);
+    const int num_heads_k = kcache.size(2);
+    TORCH_CHECK(batch_size > 0, "batch size must be postive");
+    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
+
+    at::Tensor q_padded, kcache_padded, vcache_padded;
+    if (head_size_og % 8 != 0) {
+        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    } else {
+        q_padded = q;
+        kcache_padded = kcache;
+        vcache_padded = vcache;
+    }
+
+    at::Tensor out;
+    if (out_.has_value()) {
+        out = out_.value();
+        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+    } else {
+        out = torch::empty_like(q_padded);
+    }
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size = round_multiple(head_size_og, 8);
+    const int head_size_rounded = round_multiple(head_size, 32);
+    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+    auto opts = q.options();
+
+    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+
+    Flash_fwd_params params;
+    set_params_fprop(params,
+                     batch_size,
+                     seqlen_q, seqlen_k,
+                     seqlen_q_rounded, seqlen_k_rounded,
+                     num_heads, num_heads_k,
+                     head_size, head_size_rounded,
+                     q_padded, kcache_padded, vcache_padded, out,
+                     /*cu_seqlens_q_d=*/nullptr,
+                     /*cu_seqlens_k_d=*/nullptr,
+                     /*p_ptr=*/nullptr,
+                     softmax_lse.data_ptr(),
+                     /*p_dropout=*/0.f,
+                     softmax_scale,
+                     is_causal);
+
+    at::Tensor k, v, k_padded, v_padded;
+    if (k_.has_value()) {
+        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
+        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
+        TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
+        k = k_.value();
+        v = v_.value();
+        TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
+        TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
+        TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
+        TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
+        TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
+        TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
+        CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
+        if (head_size_og % 8 != 0) {
+            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        } else {
+            k_padded = k;
+            v_padded = v;
+        }
+        params.knew_ptr = k_padded.data_ptr();
+        params.vnew_ptr = v_padded.data_ptr();
+        // All stride are in elements, not bytes.
+        params.knew_batch_stride = k_padded.stride(0);
+        params.vnew_batch_stride = v_padded.stride(0);
+        params.knew_row_stride = k_padded.stride(-3);
+        params.vnew_row_stride = v_padded.stride(-3);
+        params.knew_head_stride = k_padded.stride(-2);
+        params.vnew_head_stride = v_padded.stride(-2);
+    }
+
+    if (seqlens_k_.has_value()) {
+        auto seqlens_k = seqlens_k_.value();
+        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
+        TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
+        TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
+        CHECK_SHAPE(seqlens_k, batch_size);
+        params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
+    }
+    params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+
+    // This needs to match with run_mha_fwd_splitkv_dispatch
+    const int block_n = is_sm90 || is_sm8x
+        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
+        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
+    const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
+    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+    // In any case we don't expect seqlen_q to be larger than 64 for inference.
+    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
+    params.num_splits = num_splits;
+    if (num_splits < 1) {
+        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+    }
+    if (params.num_splits > 1) {
+        at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+        at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
+        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+        params.oaccum_ptr = out_accum.data_ptr();
+    }
+
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    // Only split kernel supports appending to KV cache
+    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());
+
+    if (head_size_og % 8 != 0) {
+        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        if (out_.has_value()) { out_.value().copy_(out); }
+        if (k_.has_value()) {
+            // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
+            // but we don't expect to get this case in practice. This is just so that the code works for that case.
+            kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+            vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+        }
+    }
+
+    return {out, softmax_lse};
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "FlashAttention";
     m.def("fwd", &mha_fwd, "Forward pass");
     m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
     m.def("bwd", &mha_bwd, "Backward pass");
     m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
+    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
 }

+ 7 - 2
csrc/flash_attn/src/block_info.h

@@ -14,9 +14,12 @@ struct BlockInfo {
     template<typename Params>
     __device__ BlockInfo(const Params &params, const int bidb)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
-        , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
         {
         }
 
@@ -33,6 +36,8 @@ struct BlockInfo {
     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
+    // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
 

+ 16 - 0
csrc/flash_attn/src/flash.h

@@ -80,6 +80,18 @@ struct Flash_fwd_params : public Qkv_params {
 
     int *__restrict__ blockmask;
 
+    // The K_new and V_new matrices.
+    void * __restrict__ knew_ptr;
+    void * __restrict__ vnew_ptr;
+
+    // The stride between rows of the Q, K and V matrices.
+    index_t knew_batch_stride;
+    index_t vnew_batch_stride;
+    index_t knew_row_stride;
+    index_t vnew_row_stride;
+    index_t knew_head_stride;
+    index_t vnew_head_stride;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;
@@ -99,6 +111,10 @@ struct Flash_fwd_params : public Qkv_params {
     bool is_bf16;
     bool is_causal;
 
+    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+    bool is_seqlens_k_cumulative;
+
     int num_splits;  // For split-KV version
 };
 

+ 150 - 61
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -617,7 +617,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
 
     using Element = typename Kernel_traits::Element;
@@ -635,7 +635,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
     constexpr int kNWarps = Kernel_traits::kNWarps;
 
+    using GmemTiledCopyO = std::conditional_t<
+        !Split,
+        typename Kernel_traits::GmemTiledCopyOaccum,
+        typename Kernel_traits::GmemTiledCopyO
+    >;
+    using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
+
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
+    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
 
     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
@@ -649,19 +658,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
         // Otherwise we might read OOB elements from gK and gV,
         // or get wrong results when we combine gOaccum from different blocks.
+        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
         const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
             + m_block * kBlockM) * params.d_rounded;
         const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
-        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
-                                    Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                    Stride<Int<kHeadDim>, _1>{});
-        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
+        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+                                      Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                     make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                       Shape<Int<kBlockM>>{}, Stride<_1>{});
 
-        typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+        GmemTiledCopyO gmem_tiled_copy_Oaccum;
         auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
         Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
-        Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+        Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
         clear(tOrOaccum);
         // Construct identity layout for sO
         Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
@@ -679,7 +690,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         #pragma unroll
         for (int m = 0; m < size<1>(tOgOaccum); ++m) {
             const int row = get<0>(tOcO(0, m, 0));
-            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = -INFINITY; }
+            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
         }
         return;
     }
@@ -695,6 +706,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
         + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+    const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -702,15 +717,26 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.k_row_stride, _1{}));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.v_row_stride, _1{}));
+    // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+    // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+    // This maps to accessing the first 64 rows of knew_ptr.
+    Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+                                             + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                               make_stride(params.knew_row_stride, _1{}));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+    Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+                                             + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                               make_stride(params.vnew_row_stride, _1{}));
 
     Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutQ{});
-    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
-    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
-                            typename Kernel_traits::SmemLayoutKV{});
+    Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
     Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
@@ -721,8 +747,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
     Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+    Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
 
     typename Kernel_traits::TiledMma tiled_mma;
@@ -787,32 +815,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                        binfo.actual_seqlen_q - m_block * kBlockM);
-    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
-
-    if (Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<0>();
-        __syncthreads();
-        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
-        __syncthreads();
-    }
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
-                                       binfo.actual_seqlen_k - n_block * kBlockN);
+    flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K>(
+        gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV,
+        binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+    );
     cute::cp_async_fence();
-    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
-    // __syncthreads();
 
-    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
-        flash::cp_async_wait<1>();
-        __syncthreads();
-        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
-    }
+    // flash::cp_async_wait<0>();
+    // __syncthreads();
+    // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
+    // __syncthreads();
 
     clear(acc_o);
 
@@ -834,19 +849,37 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::cp_async_wait<0>();
         __syncthreads();
 
+        if constexpr (Append_KV) {
+            // if (cute::thread0()) { print(tKgK); }
+            // if (cute::thread0()) { print(tKsK); }
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+            // __syncthreads();
+            // if (cute::thread0()) { print(tKgK); }
+            // __syncthreads();
+        }
+
         // Advance gV
         if (masking_step > 0) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
+            );
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV,
+                binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
             );
         }
         cute::cp_async_fence();
 
-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        flash::gemm(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
@@ -869,19 +902,39 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         flash::cp_async_wait<0>();
         __syncthreads();
+        // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
+        // __syncthreads();
+
+        // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("n_block = %d, n_block_min = %d\n", n_block, n_block_min); }
+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
+
         if (n_block > n_block_min) {
             // Advance gK
+            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p\n", tKgKnew.data()); }
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
+            // if (tidx == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("tKgKnew = %p, row_idx_switch = %d\n", tKgKnew.data(), binfo.seqlen_k_cache - (n_block - 1) * kBlockN); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
+                binfo.seqlen_k_cache - (n_block - 1) * kBlockN
+            );
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
         }
 
-        // TODO: when we have key_padding_mask we'll need to Check_inf
+        // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
 
         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
@@ -905,22 +958,45 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         clear(acc_s);
         flash::cp_async_wait<0>();
         __syncthreads();
+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("n_split_idx = %d, bidh = %d, params.h_h_k_ratio = %d, seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", n_split_idx, bidh, params.h_h_k_ratio, binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tKsK, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
         // Advance gV
         tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        if (Append_KV) { tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); }
+        flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+            gmem_tiled_copy_QKV, tVgV, tVgVnew, tVsV, tKVcKV, tKVpKV, 0, binfo.seqlen_k_cache - n_block * kBlockN
+        );
         cute::cp_async_fence();
 
-        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        flash::gemm(
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
 
         flash::cp_async_wait<0>();
         __syncthreads();
+        if constexpr (Append_KV) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("seqlen_k_cache = %d, (nblock + 1) * kBlockN = %d\n", binfo.seqlen_k_cache, (n_block + 1) * kBlockN); }
+            if (bidh % params.h_h_k_ratio == 0 && binfo.seqlen_k_cache < (n_block + 1) * kBlockN) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tVsV, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            }
+        }
         if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            if (Append_KV) { tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); }
+            flash::copy_2_sources</*Is_2_sources=*/Append_KV, /*Is_even_MN=*/true, Is_even_K>(
+                gmem_tiled_copy_QKV, tKgK, tKgKnew, tKsK, tKVcKV, tKVpKV, 0,
+                binfo.seqlen_k_cache - (n_block - 1) * kBlockN
+            );
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -942,49 +1018,60 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
     Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+    // if (cute::thread0()) { print(acc_o_rowcol); }
     Tensor lse = make_fragment_like(scores_sum);
     #pragma unroll
     for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
         float sum = scores_sum(mi);
         float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? -INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
+        lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
         float scale = inv_sum;
         #pragma unroll
         for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
     }
-
+    // if (cute::thread0()) { print(lse); }
     // if (cute::thread0()) { print(acc_o_rowcol); }
 
-    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
+    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
-    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomOaccum{}, tiled_mma);
+    using SmemTiledCopyO = std::conditional_t<
+        !Split,
+        typename Kernel_traits::SmemCopyAtomO,
+        typename Kernel_traits::SmemCopyAtomOaccum
+    >;
+    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
     auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
-    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(acc_o);        // ((Atom,AtomNum), MMA_M, MMA_N)
+    Tensor rO = flash::convert_type<ElementO>(acc_o);
+    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
 
-    // sO has the same size as sQ, so we don't need to sync here.
-    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
+    // sOaccum is larger than sQ, so we need to syncthreads here
+    // TODO: allocate enough smem for sOaccum
+    if constexpr (Split) { __syncthreads(); }
 
     cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
 
+    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
                                          + m_block * kBlockM) * params.d_rounded;
     const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
 
-    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
+    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                 Stride<Int<kHeadDim>, _1>{});
-    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
+                                 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                    Shape<Int<kBlockM>>{}, Stride<_1>{});
+    // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
 
-    typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+    GmemTiledCopyO gmem_tiled_copy_Oaccum;
     auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
     Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
 
     __syncthreads();
 
-    Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
     cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
 
     Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
@@ -1014,6 +1101,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
         gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
+    // __syncthreads();
+    // if (cute::thread0()) { print(tOgOaccum); }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1039,16 +1128,16 @@ inline __device__ void compute_attn(const Params &params) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
-    const int bidb = blockIdx.z / params.h;
+    const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
     // The block index for the head.
-    const int bidh = blockIdx.z - bidb * params.h;
-    const int n_split_idx = blockIdx.y;
-    const int num_n_splits = gridDim.y;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
+    const int n_split_idx = Split ? blockIdx.y : 0;
+    const int num_n_splits = Split ? gridDim.y : 1;
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+ 41 - 31
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -15,9 +15,9 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {
     flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
 }
 
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
 __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params);
+    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }
 
 template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
@@ -63,45 +63,55 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename Kernel_traits>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
+    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
     constexpr size_t smem_size = Kernel_traits::kSmemSize;
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
-    dim3 grid(num_m_block, params.num_splits, params.b * params.h);
+    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    // TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit.
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
-                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                if (smem_size >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                }
-                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                        // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
+                        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV, IsEvenKConst, Split, Append_KV>;
+                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                        if (smem_size >= 48 * 1024) {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                        }
+                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    });
+                });
             });
         });
     });
-    dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
-    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-        if (params.num_splits <= 2) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 4) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 8) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 16) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 32) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        } else if (params.num_splits <= 64) {
-            flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        // } else if (params.num_splits <= 128) {
-        //     flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
-        }
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-    });
+    if (params.num_splits > 1) {
+        dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
+        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
+            if (params.num_splits <= 2) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 4) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 8) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 16) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 32) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 64) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 128) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            }
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+    }
 }
 
 template<typename T, int Headdim>

+ 69 - 2
csrc/flash_attn/src/utils.h

@@ -291,7 +291,7 @@ template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bo
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
 inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                            Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
+                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
@@ -355,4 +355,71 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-}  // namespace flash
+template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S0,
+                                      Tensor<Engine0, Layout0> const &S1,
+                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                      Tensor<Engine3, Layout3> const &predicate_K,
+                                      const int max_MN=0, const int row_idx_switch=0) {
+    CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K
+    // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
+    #pragma unroll
+    for (int m = 0; m < size<1>(S0); ++m) {
+        auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
+        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S0); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        } else if (Clear_OOB_MN) {
+            cute::clear(D(_, m, _));
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
+                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                      Tensor<Engine3, Layout3> const &predicate_K,
+                                      const int max_MN=0, const int min_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(S(_, m, k), D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash

+ 1 - 0
flash_attn/__init__.py

@@ -7,4 +7,5 @@ from flash_attn.flash_attn_interface import (
     flash_attn_varlen_func,
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
 )

+ 72 - 0
flash_attn/flash_attn_interface.py

@@ -5,6 +5,7 @@ from einops import rearrange
 # isort: off
 # We need to import the CUDA kernels after importing torch
 import flash_attn_2_cuda as flash_attn_cuda
+
 # isort: on
 
 
@@ -790,3 +791,74 @@ def flash_attn_varlen_func(
         causal,
         return_attn_probs,
     )
+
+
+def flash_attn_with_kvcache(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    cache_seqlens=None,
+    softmax_scale=None,
+    causal=False,
+    num_splits=0,
+):
+    """
+    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
+    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
+    the previous step, and update them with the new keys/values from the current step, and do
+    attention with the updated cache, all in 1 kernel.
+
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
+    Does not support backward pass.
+
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
+        v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
+        k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
+            k_cache, starting at the indices specified by cache_seqlens.
+        v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
+        cache_seqlens: (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
+           If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
+           to automatically determine the number of splits.
+           Don't change this unless you know what you are doing.
+
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+    """
+    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
+    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
+        q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
+    )
+    return out

+ 25 - 9
flash_attn/utils/generation.py

@@ -348,8 +348,14 @@ def decode_speculative(
     )
 
     def sample_tokens(
-        input_ids, model, inference_params, sample_fn, num_tokens=1, cg=False, decoding=True,
-        last_token_logits=False
+        input_ids,
+        model,
+        inference_params,
+        sample_fn,
+        num_tokens=1,
+        cg=False,
+        decoding=True,
+        last_token_logits=False,
     ):
         """Sample `num_tokens` tokens from the model, given the previous logits.
         Also return the logits of the sampled tokens.
@@ -374,12 +380,18 @@ def decode_speculative(
         sequences = []
         if decoding:
             assert seqlen == 1
-            position_ids = torch.full(
-                (batch_size, 1),
-                inference_params.sequence_len_offset,
-                dtype=torch.long,
-                device=input_ids.device,
+            position_ids = repeat(
+                torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+                + inference_params.sequence_len_offset,
+                "s -> b s",
+                b=batch_size,
             )
+            # position_ids = torch.full(
+            #     (batch_size, 1),
+            #     inference_params.sequence_len_offset,
+            #     dtype=torch.long,
+            #     device=input_ids.device,
+            # )
         else:
             position_ids = None
         logits = logits_postprocess_fn(
@@ -399,7 +411,11 @@ def decode_speculative(
                 )
                 logits = logits_postprocess_fn(
                     logits_forward_fn(
-                        model, rearrange(next_token, "b -> b 1"), position_ids, inference_params, cg=cg
+                        model,
+                        rearrange(next_token, "b -> b 1"),
+                        position_ids,
+                        inference_params,
+                        cg=cg,
                     )
                 )
                 inference_params.sequence_len_offset += 1
@@ -420,7 +436,7 @@ def decode_speculative(
         sample_fn=sample_fn,
         last_token_logits=True,
         inference_params=inference_params_draft,
-        cg=cg
+        cg=cg,
     )
 
     if debug:

+ 90 - 0
tests/test_flash_attn.py

@@ -11,6 +11,7 @@ from flash_attn import (
     flash_attn_varlen_func,
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
 )
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size
@@ -1465,6 +1466,95 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4
 
+@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("num_splits", [1, 0])
+# @pytest.mark.parametrize("num_splits", [0])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mqa"])
+@pytest.mark.parametrize("new_kv", [False, True])
+# @pytest.mark.parametrize("new_kv", [False])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 128),
+        (1, 339),
+        (3, 1024),
+        (64, 800),
+        (64, 256),
+        (3, 799),
+        (64, 2048),
+        (16, 20000),
+        (1, 128 * 1024),
+        (16, 128 * 1024),
+        (128, 128),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+def test_flash_attn_kvcache(seqlen_q, seqlen_k, d, causal, new_kv, mha_type, num_splits, dtype):
+    if seqlen_q > seqlen_k and new_kv:
+        pytest.skip()
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 6
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    assert nheads % nheads_k == 0
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    if new_kv:
+        k = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seqlen_q, nheads_k, d, device=device, dtype=dtype)
+    else:
+        k, v = None, None
+    k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+    cache_seqlens = torch.randint(0, (seqlen_k - seqlen_q + 1) if new_kv else (seqlen_k + 1), (batch_size, ), dtype=torch.int32, device=device)
+    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
+    # k_cache[:, 64:] = -1
+    k_cache_ref = k_cache.clone()
+    v_cache_ref = v_cache.clone()
+    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
+    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
+    if new_kv:
+        update_mask = torch.logical_and(cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_q)
+        k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
+        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
+    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+    out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits)
+    # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
+    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
+    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
+    # m = qk.amax(-1, keepdim=True)
+    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
+    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
+    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
+    # probs = torch.softmax(qk, dim=-1)
+    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_q if new_kv else 0)
+    out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal)
+    out_pt, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal,
+                              upcast=False, reorder_ops=True)
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
+    if new_kv:
+        assert torch.equal(k_cache, k_cache_ref)
+        assert torch.equal(v_cache, v_cache_ref)
+
 
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])