
Support paged KV cache in AMD ROCm (#1198)

* Integrate ck branch of ck_tile/fa_bwd_opt

* Assume dq and q share the same stride

* update ck

* Integrate more stride of dq_acc

* Revert fwd dropout

* Fix parameter order

* Integrate ck with more stride

* update the limit of hdim of bwd

* Check argument

* Add test_flash_attn_causal

* Support unpad lse

* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic

* Fix stride and Kn0

* Fix CK sync issue

* Fix typo

* Update CK for changing of fmha_fwd_args

* Add kvcache tmp

* Add kvcache

* Fix comment

* Sync behavior with ck

* Update CK to develop

* remove large test case

* Add kvcache test

* Fix page_block_size in arg

* Minor fix

* Fix stride error

* Update seqlen of kvcache before splitkv

* Fix compile error

* Fix bug when hdim is not a multiple of 8

* Fit ck arg

* support adaptive num_splits

* add more tests

* Refine test tolerance

* update CK

* Move override_num_splits_if_necessary into cpp

* update ck

* Update ck

* Support different flag for different version of hip

* remove coerce-illegal, because this is not required in FA

* Update ck to fix scratch memory

* Add coerce-illegal in some version

* Add compile flag for rtn rounding

* remove redundant init

* Using env var to switch rounding mode

* update ck
rocking 6 months ago
parent
commit
e2182cc21d

+ 1 - 1
csrc/composable_kernel

@@ -1 +1 @@
-Subproject commit 8182976c37433808b5e3a27a6536d1b74b0c23a1
+Subproject commit a9b170b54195ab667ca814f80dd5dfbf4ad772f5

+ 32 - 9
csrc/flash_attn_ck/flash_api.cpp

@@ -20,16 +20,16 @@ mha_fwd(at::Tensor &q,
         c10::optional<at::Generator> gen_);
 
 std::vector<at::Tensor>
-mha_varlen_fwd(at::Tensor &q,                            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-               c10::optional<at::Tensor> &out_,          // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor &cu_seqlens_q,           // b+1
-               const at::Tensor &cu_seqlens_k,           // b+1
-               c10::optional<at::Tensor> &seqused_k,     // b. If given, only this many elements of each batch element's keys are used.
+mha_varlen_fwd(at::Tensor &q,                               // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,                         // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,                         // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               c10::optional<at::Tensor> &out_,             // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q,              // b+1
+               const at::Tensor &cu_seqlens_k,              // b+1
+               c10::optional<at::Tensor> &seqused_k,        // b. If given, only this many elements of each batch element's keys are used.
                c10::optional<const at::Tensor> &leftpad_k_, // batch_size
-               c10::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
-               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+               c10::optional<at::Tensor> &block_table_,     // batch_size x max_num_blocks_per_seq
+               c10::optional<at::Tensor> &alibi_slopes_,    // num_heads or b x num_heads
                int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,
@@ -89,6 +89,28 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state);
 
+std::vector<at::Tensor>
+mha_fwd_kvcache(at::Tensor &q,                                     // batch_size x seqlen_q x num_heads x head_size
+                const at::Tensor &kcache,                          // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                const at::Tensor &vcache,                          // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                c10::optional<const at::Tensor> &k_,               // batch_size x seqlen_knew x num_heads_k x head_size
+                c10::optional<const at::Tensor> &v_,               // batch_size x seqlen_knew x num_heads_k x head_size
+                c10::optional<const at::Tensor> &seqlens_k_,       // batch_size
+                c10::optional<const at::Tensor> &rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<const at::Tensor> &leftpad_k_,       // batch_size
+                c10::optional<at::Tensor> &block_table_,           // batch_size x max_num_blocks_per_seq
+                c10::optional<at::Tensor> &alibi_slopes_,          // num_heads or batch_size x num_heads
+                c10::optional<at::Tensor> &out_,                   // batch_size x seqlen_q x num_heads x head_size
+                const float softmax_scale,
+                bool is_causal,
+                int window_size_left,
+                int window_size_right,
+                const float softcap,
+                bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+                int num_splits);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.doc() = "FlashAttention";
@@ -96,4 +118,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
     m.def("bwd", &mha_bwd, "Backward pass");
     m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
+    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
 }
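The binding above exposes the CK KV-cache forward to Python. Below is a minimal sketch of driving it directly for a plain (non-paged) decode step; the module name `flash_attn_2_cuda` and the tensor shapes are assumptions for illustration, while the argument order follows the `mha_fwd_kvcache` declaration above.

```python
# Hedged sketch: exercising the new fwd_kvcache binding directly.
# Assumes the CK-backed extension is importable as flash_attn_2_cuda
# (the actual module name depends on the build).
import math
import torch
import flash_attn_2_cuda as flash_attn_cuda

b, s_q, s_k, h, h_k, d = 2, 1, 4096, 32, 8, 128
q = torch.randn(b, s_q, h, d, device="cuda", dtype=torch.float16)
kcache = torch.randn(b, s_k, h_k, d, device="cuda", dtype=torch.float16)
vcache = torch.randn_like(kcache)
seqlens_k = torch.full((b,), 1024, dtype=torch.int32, device="cuda")

out, softmax_lse = flash_attn_cuda.fwd_kvcache(
    q, kcache, vcache,
    None, None,          # k_, v_: no new keys/values to append
    seqlens_k,           # current lengths in the cache
    None, None,          # rotary_cos_, rotary_sin_
    None,                # cache_batch_idx_
    None,                # leftpad_k_
    None,                # block_table_ (contiguous, non-paged cache)
    None,                # alibi_slopes_
    None,                # out_
    1.0 / math.sqrt(d),  # softmax_scale
    True,                # is_causal
    -1, -1,              # window_size_left / window_size_right
    0.0,                 # softcap
    False,               # is_rotary_interleaved
    0,                   # num_splits: <= 0 lets the heuristic choose
)
```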

+ 34 - 0
csrc/flash_attn_ck/flash_common.cpp

@@ -0,0 +1,34 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "flash_common.hpp"
+
+namespace flash {
+int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
+{
+    int device;
+    auto status = hipGetDevice(&device);
+    if(status != hipSuccess)
+        return num_splits;
+
+    hipDeviceProp_t props{};
+    status = hipGetDeviceProperties(&props, device);
+    if(status != hipSuccess)
+        return num_splits;
+
+    // TODO - tile size should match the TileFmhaShape, hardcode for now
+    const int kM0 = 128;
+    const int kN1 = hdim_v;
+
+    const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
+    const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;
+
+    if(num_splits < 1 && p_drop == 0.0f)
+        return num_splits_heuristic_ck(
+            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
+
+    return num_splits;
+}
+
+} // namespace flash

+ 39 - 1
csrc/flash_attn_ck/flash_common.hpp

@@ -24,7 +24,7 @@
 namespace flash {
 // Copy from PyTorch
 // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
-static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
+inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
   if (arg.captured_) {
     // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
     // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
@@ -35,4 +35,42 @@ static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
   }
 }
 
+inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
+    // If we have enough to almost fill the SMs, then just use 1 split
+    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
+    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+    float max_efficiency = 0.f;
+    std::vector<float> efficiency;
+    efficiency.reserve(max_splits);
+    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
+    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
+    // (i.e. it's 11 splits anyway).
+    // So we check if the number of blocks per split is the same as the previous num_splits.
+    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
+        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
+    };
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        if (!is_split_eligible(num_splits)) {
+            efficiency.push_back(0.f);
+        } else {
+            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
+            float eff = n_waves / ceil(n_waves);
+            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+            if (eff > max_efficiency) { max_efficiency = eff; }
+            efficiency.push_back(eff);
+        }
+    }
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        if (!is_split_eligible(num_splits)) { continue; }
+        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
+            // printf("num_splits chosen = %d\n", num_splits);
+            return num_splits;
+        }
+    }
+    return 1;
+}
+
+int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);
+
 } // namespace flash
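For readers new to the split-KV scheduling, here is a Python port of `num_splits_heuristic_ck`, intended purely as illustration. `override_num_splits_if_necessary` (flash_common.cpp) only calls it when the caller requested `num_splits < 1` with dropout disabled, feeding it `batch * nhead * num_m_blocks` work items (kM0 = 128 tiles) and twice the CU count.

```python
# Illustrative Python port of num_splits_heuristic_ck (not used by the extension).
import math

def num_splits_heuristic_ck(batch_nheads_mblocks, num_SMs, num_n_blocks, max_splits):
    # Enough blocks to nearly fill the device: a single split is best.
    if batch_nheads_mblocks >= 0.8 * num_SMs:
        return 1
    max_splits = min(max_splits, num_SMs, num_n_blocks)
    ceildiv = lambda a, b: (a + b - 1) // b
    # A split count is only "eligible" if it changes the number of n-blocks
    # each split processes (see the 64-block example in the comment above).
    def is_split_eligible(n):
        return n == 1 or ceildiv(num_n_blocks, n) != ceildiv(num_n_blocks, n - 1)
    efficiency, max_efficiency = [], 0.0
    for n in range(1, max_splits + 1):
        if not is_split_eligible(n):
            efficiency.append(0.0)
            continue
        n_waves = batch_nheads_mblocks * n / num_SMs
        eff = n_waves / math.ceil(n_waves)   # how evenly the last wave is filled
        max_efficiency = max(max_efficiency, eff)
        efficiency.append(eff)
    # Pick the smallest eligible split count within 85% of the best efficiency.
    for n in range(1, max_splits + 1):
        if is_split_eligible(n) and efficiency[n - 1] >= 0.85 * max_efficiency:
            return n
    return 1

# Eligibility example from the comment above: with 64 n-blocks, 11 and 12 splits
# both need ceil(64 / n) == 6 blocks per split, so 12 is skipped.
ceildiv = lambda a, b: (a + b - 1) // b
assert ceildiv(64, 11) == ceildiv(64, 12) == 6
```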

+ 71 - 32
csrc/flash_attn_ck/mha_bwd.cpp

@@ -11,7 +11,8 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                                        std::string dtype,
                                        int head_size,
                                        bool has_dropout,
-                                       bool enable_alibi)
+                                       bool enable_alibi,
+                                       bool deterministic)
 {
     return fmha_bwd_traits{head_size,
                            head_size,
@@ -20,7 +21,9 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                            mask.type,
                            enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                            false,    // has_dbias
-                           has_dropout};
+                           has_dropout,
+                           false, // s_randval
+                           deterministic};
 }
 
 fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
@@ -39,6 +42,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    const at::Tensor out,
                                    const at::Tensor softmax_lse,
                                    const at::Tensor dout,
+                                   at::Tensor dq_acc,
                                    at::Tensor d,
                                    at::Tensor dq,
                                    at::Tensor dk,
@@ -49,41 +53,57 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    uint64_t drop_offset)
 {
     // q: (batch_size, seqlen_q, nheads, hdim)
+    ck_tile::index_t batch_stride_q = q.stride(0);
+    ck_tile::index_t stride_q = q.stride(1);
+    ck_tile::index_t nhead_stride_q = q.stride(2);
+
     // k: (batch_size, seqlen_k, nheads_k, hdim)
+    ck_tile::index_t batch_stride_k = k.stride(0);
+    ck_tile::index_t stride_k = k.stride(1);
+    ck_tile::index_t nhead_stride_k = k.stride(2);
+
     // v: (batch_size, seqlen_k, nheads_k, hdim)
+    ck_tile::index_t batch_stride_v = v.stride(0);
+    ck_tile::index_t stride_v = v.stride(1);
+    ck_tile::index_t nhead_stride_v = v.stride(2);
+
     // o: (batch_size, seqlen_q, nheads, hdim)
-    // dq: (batch_size, seqlen_q, nheads, hdim)
-    // dk_expanded: (batch_size, seqlen_k, nheads, hdim)
-    // dv_expanded: (batch_size, seqlen_k, nheads, hdim)
-    // do: (batch_size, seqlen_q, nheads, hdim)
+    ck_tile::index_t batch_stride_o = out.stride(0);
+    ck_tile::index_t stride_o = out.stride(1);
+    ck_tile::index_t nhead_stride_o = out.stride(2);
 
-    // alibi_slopes:(batch_size, nheads) or (nhead)
     // lse: (batch_size, nheads, seqlen_q)
-    // d: (batch_size, nheads, seqlen_q)
+    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
+    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
 
-    ck_tile::index_t stride_q = q.stride(1);
-    ck_tile::index_t stride_k = k.stride(1);
-    ck_tile::index_t stride_v = v.stride(1);
-    ck_tile::index_t stride_o = out.stride(1);
+    // do: (batch_size, seqlen_q, nheads, hdim)
+    ck_tile::index_t batch_stride_do = dout.stride(0);
     ck_tile::index_t stride_do = dout.stride(1);
-    ck_tile::index_t stride_dk = dk.stride(1);
-    ck_tile::index_t stride_dv = dv.stride(1);
-
-    ck_tile::index_t nhead_stride_q = q.stride(2);
-    ck_tile::index_t nhead_stride_k = k.stride(2);
-    ck_tile::index_t nhead_stride_v = v.stride(2);
-    ck_tile::index_t nhead_stride_o = out.stride(2);
     ck_tile::index_t nhead_stride_do = dout.stride(2);
-    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
 
-    ck_tile::index_t batch_stride_q = q.stride(0);
-    ck_tile::index_t batch_stride_k = k.stride(0);
-    ck_tile::index_t batch_stride_v = v.stride(0);
-    ck_tile::index_t batch_stride_o = out.stride(0);
-    ck_tile::index_t batch_stride_do = dout.stride(0);
-    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
+    // d: (batch_size, nheads, seqlen_q)
+    // CK assume d share the same stride with lse
+
+    // dq: (batch_size, seqlen_q, nheads, hdim)
+    ck_tile::index_t batch_stride_dq = dq.stride(0);
+    ck_tile::index_t stride_dq = dq.stride(1);
+    ck_tile::index_t nhead_stride_dq = dq.stride(2);
+
+    // dk_expanded: (batch_size, seqlen_k, nheads, hdim)
     ck_tile::index_t batch_stride_dk = dk.stride(0);
+    ck_tile::index_t stride_dk = dk.stride(1);
+    ck_tile::index_t nhead_stride_dk = dk.stride(2);
+
+    // dv_expanded: (batch_size, seqlen_k, nheads, hdim)
     ck_tile::index_t batch_stride_dv = dv.stride(0);
+    ck_tile::index_t stride_dv = dv.stride(1);
+    ck_tile::index_t nhead_stride_dv = dv.stride(2);
+
+    // dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
+    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
+    ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
+    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
 
     float p_undrop = 1.0 - p_dropout;
 
@@ -96,6 +116,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
         TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
         alibi_slopes_ptr = alibi_slopes.data_ptr();
+        // alibi_slopes:(batch_size, nheads) or (nhead)
         stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
 
@@ -112,6 +133,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                          dk.data_ptr(),
                          dv.data_ptr(),
                          nullptr, // dbias
+                         dq_acc.data_ptr(), // dq_acc
                          nullptr, // seqstart_q
                          nullptr, // seqstart_k
                          nullptr, // seqlen_k_ptr
@@ -132,6 +154,8 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                          stride_o,
                          0, // stride_randval
                          stride_do,
+                         stride_dq_acc,
+                         stride_dq,
                          stride_dk,
                          stride_dv,
                          0, // stride_dbias, FA without bias
@@ -143,6 +167,10 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                          0, // nhead_stride_randval
                          nhead_stride_do,
                          nhead_stride_lse,
+                         nhead_stride_dq_acc,
+                         nhead_stride_dq,
+                         nhead_stride_dk,
+                         nhead_stride_dv,
                          0, // nhead_stride_dbias, FA without dbias
                          batch_stride_q,
                          batch_stride_k,
@@ -152,15 +180,17 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                          0, // batch_stride_randval
                          batch_stride_do,
                          batch_stride_lse,
+                         batch_stride_dq_acc,
+                         batch_stride_dq,
                          batch_stride_dk,
                          batch_stride_dv,
                          0  , // batch_stride_dbias, FA without dbias
+                         split_stride_dq_acc,
                          mask.left,
                          mask.right,
                          static_cast<ck_tile::index_t>(mask.type),
                          p_dropout,
                          p_undrop,
-                         false, // s_randval
                          {drop_seed, drop_offset}};
 }
 
@@ -224,7 +254,7 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
     const int num_heads_k = k.size(2);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
-    TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
+    TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -296,7 +326,15 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
 
     auto opts = q.options();
     auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
-    // TODO - CK does not support dq_accum
+    at::Tensor dq_accum;
+
+    if (!deterministic) {
+        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+    } else {
+        const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
+        const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
+        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+    }
 
     at::Tensor dk_expanded, dv_expanded;
     if (num_heads_k != num_heads) {  // MQA / GQA
@@ -326,10 +364,9 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
 
     if (seqlen_q > 0) {
         ck_tile::stream_config stream_config{stream};
-        dq.zero_(); // ck use atomic operation on dq
 
         auto traits =
-            get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
+            get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
 
         auto args =
             get_ck_fmha_bwd_args(
@@ -347,6 +384,7 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
                 out,
                 softmax_lse,
                 dout_padded,
+                dq_accum,
                 softmax_d,
                 dq,
                 dk_expanded,
@@ -356,7 +394,8 @@ mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num
                 drop_seed,
                 drop_offset);
 
-        fmha_bwd(traits, args, stream_config);
+        float t = fmha_bwd(traits, args, stream_config);
+        TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
         dk_expanded.zero_();
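The dq_acc workspace added in this hunk replaces zeroing dq and accumulating into it directly. A small sketch of the shape logic follows, spelling out the deterministic branch; treating the leading split dimension as what fixes the reduction order is an inference from the code and the `deterministic` flag, not something the diff states explicitly.

```python
# Sketch of the dq_accum shape chosen in mha_bwd.cpp above (illustration only).
def dq_accum_shape(batch_size, seqlen_q, seqlen_k, num_heads, head_size_8x, deterministic):
    if not deterministic:
        # One shared accumulator; accumulation order across k-blocks may vary.
        return (1, batch_size, seqlen_q, num_heads, head_size_8x)
    # Deterministic mode: one accumulator slice per k-block tile of size kN0,
    # so the final reduction over splits happens in a fixed order.
    kN0 = 128 if head_size_8x <= 128 else 64
    nsplits = (seqlen_k + kN0 - 1) // kN0
    return (nsplits, batch_size, seqlen_q, num_heads, head_size_8x)

# e.g. batch 4, seqlen 2048, 16 heads, hdim 128:
assert dq_accum_shape(4, 2048, 2048, 16, 128, deterministic=False)[0] == 1
assert dq_accum_shape(4, 2048, 2048, 16, 128, deterministic=True)[0] == 16  # ceil(2048 / 128)
```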

+ 9 - 12
csrc/flash_attn_ck/mha_fwd.cpp

@@ -96,8 +96,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                          v.data_ptr(),
                          alibi_slopes_ptr, // bias
                          has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
-                         nullptr, // lse_acc
-                         nullptr, // o_acc
                          has_lse ? softmax_lse.data_ptr() : nullptr,
                          out.data_ptr(),
                          nullptr, // seqstart_q
@@ -111,7 +109,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                          d,             // hdim_v
                          h,             // nhead
                          h_k,           // nhead_k
-                         1,             // num_splits
                          softmax_scale, // scale_s
                          1,             // scale_p
                          1,             // scale_o
@@ -120,7 +117,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                          stride_v,
                          stride_alibi_slopes,
                          stride_randval,
-                         0, // stride_o_acc,
                          stride_o,
                          nhead_stride_q,
                          nhead_stride_k,
@@ -128,8 +124,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                          0, // nhead_stride_bias, FA without bias
                          nhead_stride_randval,
                          nhead_stride_lse,
-                         0, // nhead_stride_lse_acc
-                         0, // nhead_stride_o_acc
                          nhead_stride_o,
                          batch_stride_q,
                          batch_stride_k,
@@ -137,11 +131,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                          0, // batch_stride_bias, FA without bias
                          batch_stride_randval,
                          batch_stride_lse,
-                         0, // batch_stride_lse_acc
-                         0, // batch_stride_o_acc
                          batch_stride_o,
-                         0, // split_stride_lse_acc
-                         0, // split_stride_o_acc
                          mask.left,
                          mask.right,
                          static_cast<ck_tile::index_t>(mask.type),
@@ -299,7 +289,13 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
-            get_ck_fmha_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
+            get_ck_fmha_fwd_traits(
+                mask,
+                q_dtype_str,
+                head_size_8x,
+                has_dropout,
+                has_lse,
+                alibi_slopes_.has_value());
 
         auto args =
             get_ck_fmha_fwd_args(
@@ -324,7 +320,8 @@ mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num
                 drop_seed,
                 drop_offset);
 
-        fmha_fwd(traits, args, stream_config);
+        float t = fmha_fwd(traits, args, stream_config);
+        TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
     }
     else {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.

+ 568 - 0
csrc/flash_attn_ck/mha_fwd_kvcache.cpp

@@ -0,0 +1,568 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "flash_common.hpp"
+
+#include "fmha_fwd.hpp"
+#include "rotary.hpp"
+
+fmha_fwd_appendkv_traits get_ck_fmha_fwd_appendkv_traits(std::string dtype,
+                                                        int head_size,
+                                                        int rotary_dim,
+                                                        bool is_rotary_interleaved)
+{
+    rope_enum rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
+                                                                   : rope_enum::half_rotated)
+                                          : rope_enum::none);
+
+    return fmha_fwd_appendkv_traits{head_size,
+                                    head_size,
+                                    dtype,
+                                    true,  // is_v_rowmajor
+                                    rope_type};
+}
+
+fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask,
+                                                       std::string dtype,
+                                                       int head_size,
+                                                       bool has_lse,
+                                                       bool enable_alibi)
+{
+    return fmha_fwd_splitkv_traits{head_size,
+                                   head_size,
+                                   dtype,
+                                   false, // is_group_mode
+                                   true, // is_v_rowmajor
+                                   mask.type,
+                                   enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
+                                   has_lse,
+                                   false}; // do_fp8_static_quant
+}
+
+fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
+                                                     const int seqlen_q,
+                                                     const int seqlen_knew,
+                                                     const int h,
+                                                     const int h_k,
+                                                     const int d,
+                                                     const int rotary_dim,
+                                                     const bool has_mask,
+                                                     const int page_block_size,
+                                                     // device pointers
+                                                     const at::Tensor q,
+                                                     const at::Tensor kcache,
+                                                     const at::Tensor vcache,
+                                                     const at::Tensor knew,
+                                                     const at::Tensor vnew,
+                                                     c10::optional<const at::Tensor> &seqlens_k_,
+                                                     c10::optional<const at::Tensor> &rotary_cos_,
+                                                     c10::optional<const at::Tensor> &rotary_sin_,
+                                                     c10::optional<const at::Tensor> &cache_batch_idx_,
+                                                     c10::optional<at::Tensor> &block_table_)
+{
+    // q: (batch_size, seqlen_q, nheads, d)
+    // kcache: (batch_size_c, seqlen_k, nheads_k, d) or (num_blocks, page_block_size, nheads_k, d)
+    // vcache: (batch_size_c, seqlen_k, nheads_k, d) or (num_blocks, page_block_size, nheads_k, d)
+    // knew: (batch_size, seqlen_knew, nheads_k, d)
+    // vnew: (batch_size, seqlen_knew, nheads_k, d)
+
+    // seqlens_k: (batch_size)
+    // rotary_cos: (seqlen_ro, rotary_dim / 2)
+    // rotary_sin: (seqlen_ro, rotary_dim / 2)
+    // block_table: (batch_size, max_num_blocks_per_seq)
+
+    fmha_fwd_appendkv_args args;
+    args.q_ptr = q.data_ptr();
+    args.k_ptr = kcache.data_ptr();
+    args.knew_ptr = knew.data_ptr();
+    args.v_ptr = vcache.data_ptr();
+    args.vnew_ptr = vnew.data_ptr();
+    args.seqlen_k_ptr = seqlens_k_.has_value() ? seqlens_k_.value().data_ptr() : nullptr;
+
+    args.seqlen_q = seqlen_q;
+    args.seqlen_knew = seqlen_knew;
+    args.batch = b;
+    args.hdim_q = d;
+    args.hdim_v = d;
+    args.nhead_q = h;
+    args.nhead_k = h_k;
+
+    args.rotary_cos_ptr = rotary_cos_.has_value() ? rotary_cos_.value().data_ptr() : nullptr;
+    args.rotary_sin_ptr = rotary_sin_.has_value() ? rotary_sin_.value().data_ptr() : nullptr;
+    args.rotary_dim = rotary_dim;
+    args.has_mask = has_mask;
+
+    if (block_table_.has_value())
+    {
+        auto block_table = block_table_.value();
+        args.block_table_ptr = block_table.data_ptr();
+        args.batch_stride_block_table = block_table.stride(0);
+        args.page_block_size = page_block_size;
+    }
+    else
+    {
+        args.block_table_ptr = nullptr;
+        args.batch_stride_block_table = 0;
+        args.page_block_size = 0;
+    }
+
+    args.cache_batch_idx = cache_batch_idx_.has_value() ?
+        reinterpret_cast<int *>(cache_batch_idx_.value().data_ptr()) : nullptr;
+
+    args.batch_stride_q = q.stride(0);
+    args.stride_q = q.stride(1);
+    args.nhead_stride_q = q.stride(2);
+
+    args.batch_stride_k = kcache.stride(0);
+    args.stride_k = kcache.stride(1);
+    args.nhead_stride_k = kcache.stride(2);
+
+    args.batch_stride_knew = knew.stride(0);
+    args.stride_knew = knew.stride(1);
+    args.nhead_stride_knew = knew.stride(2);
+
+    args.batch_stride_v = vcache.stride(0);
+    args.stride_v = vcache.stride(1);
+    args.nhead_stride_v = vcache.stride(2);
+
+    args.batch_stride_vnew = vnew.stride(0);
+    args.stride_vnew = vnew.stride(1);
+    args.nhead_stride_vnew = vnew.stride(2);
+
+    return args;
+}
+
+fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
+                                                   const mask_info &mask,
+                                                   const int b,
+                                                   const int seqlen_q,
+                                                   const int seqlen_k,
+                                                   const int h,
+                                                   const int h_k,
+                                                   const int d,
+                                                   const int page_block_size,
+                                                   const int num_splits,
+                                                   float softmax_scale,
+                                                   // device pointers
+                                                   const at::Tensor q,
+                                                   const at::Tensor k,
+                                                   const at::Tensor v,
+                                                   const at::Tensor seqlens_k,
+                                                   c10::optional<const at::Tensor> &cache_batch_idx_,
+                                                   c10::optional<at::Tensor> &block_table_,
+                                                   c10::optional<at::Tensor> &alibi_slopes_,
+                                                   at::Tensor out,
+                                                   at::Tensor lse,
+                                                   at::Tensor lse_acc,
+                                                   at::Tensor out_acc)
+{
+    // q: (batch_size, seqlen_q, nheads, d)
+    // k: (batch_size, seqlen_k, nheads_k, d)
+    // v: (batch_size, seqlen_k, nheads_k, d)
+    // o: (batch_size, seqlen_q, nheads, d)
+
+    // alibi_slopes:(batch_size, nheads) or (nhead)
+    // lse: (batch_size, nheads, seqlen_q)
+    // lse_acc: (split, batch_size, nheads, seqlen_q)
+    // o_acc: (split, batch_size, nheads, seqlen_q, d)
+
+    fmha_fwd_splitkv_args args;
+    args.q_ptr = q.data_ptr();
+    args.k_ptr = k.data_ptr();
+    args.v_ptr = v.data_ptr();
+    args.bias_ptr = nullptr;
+    args.lse_acc_ptr = lse_acc.data_ptr();
+    args.o_acc_ptr = out_acc.data_ptr();
+    args.lse_ptr = nullptr;
+    args.o_ptr = out.data_ptr();
+
+    if (block_table_.has_value())
+    {
+        auto block_table = block_table_.value();
+        args.block_table_ptr = block_table.data_ptr();
+        args.batch_stride_block_table = block_table.stride(0);
+        args.page_block_size = page_block_size;
+    }
+    else
+    {
+        args.block_table_ptr = nullptr;
+        args.batch_stride_block_table = 0;
+        args.page_block_size = 0;
+    }
+
+    args.cache_batch_idx = cache_batch_idx_.has_value() ? cache_batch_idx_.value().data_ptr() : nullptr;
+
+    args.seqstart_q_ptr = nullptr;
+    args.seqstart_k_ptr = nullptr;
+    args.seqlen_k_ptr = seqlens_k.data_ptr();
+
+    args.seqlen_q = seqlen_q;
+    args.seqlen_k = seqlen_k;
+    args.batch = b;
+    args.max_seqlen_q = seqlen_q;
+    args.hdim_q = d;
+    args.hdim_v = d;
+    args.nhead_q = h;
+    args.nhead_k = h_k;
+    args.num_splits = num_splits;
+
+    args.scale_s = softmax_scale;
+    args.scale_p = 1;
+    args.scale_o = 1;
+
+    args.batch_stride_q = q.stride(0);
+    args.stride_q = q.stride(1);
+    args.nhead_stride_q = q.stride(2);
+
+    args.batch_stride_k = k.stride(0);
+    args.stride_k = k.stride(1);
+    args.nhead_stride_k = k.stride(2);
+
+    args.batch_stride_v = v.stride(0);
+    args.stride_v = v.stride(1);
+    args.nhead_stride_v = v.stride(2);
+
+    args.batch_stride_o = out.stride(0);
+    args.stride_o = out.stride(1);
+    args.nhead_stride_o = out.stride(2);
+
+    args.batch_stride_bias = 0;
+    args.stride_bias = 0;
+    args.nhead_stride_bias = 0;
+
+    args.batch_stride_lse = 0;
+    args.nhead_stride_lse = 0;
+
+    args.split_stride_lse_acc = lse_acc.stride(0);
+    args.batch_stride_lse_acc = lse_acc.stride(1);
+    args.nhead_stride_lse_acc = lse_acc.stride(2);
+
+    args.split_stride_o_acc = out_acc.stride(0);
+    args.batch_stride_o_acc = out_acc.stride(1);
+    args.nhead_stride_o_acc = out_acc.stride(2);
+    args.stride_o_acc = out_acc.stride(3);
+
+    if (has_lse) {
+        args.lse_ptr = lse.data_ptr();
+        args.batch_stride_lse = lse.stride(0);
+        args.nhead_stride_lse = lse.stride(1);
+    }
+
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
+        args.bias_ptr = alibi_slopes.data_ptr();
+        args.stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    }
+
+    args.window_size_left = mask.left;
+    args.window_size_right = mask.right;
+    args.mask_type = static_cast<ck_tile::index_t>(mask.type);
+
+    return args;
+}
+
+std::vector<at::Tensor>
+mha_fwd_kvcache(at::Tensor &q,                                      // batch_size x seqlen_q x num_heads x head_size
+                const at::Tensor &kcache,                           // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                const at::Tensor &vcache,                           // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+                c10::optional<const at::Tensor> &k_,                // batch_size x seqlen_knew x num_heads_k x head_size
+                c10::optional<const at::Tensor> &v_,                // batch_size x seqlen_knew x num_heads_k x head_size
+                c10::optional<const at::Tensor> &seqlens_k_,        // batch_size
+                c10::optional<const at::Tensor> &rotary_cos_,       // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &rotary_sin_,       // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<const at::Tensor> & /*leftpad_k_*/,   // batch_size
+                c10::optional<at::Tensor> &block_table_,            // batch_size x max_num_blocks_per_seq
+                c10::optional<at::Tensor> &alibi_slopes_,           // num_heads or batch_size x num_heads
+                c10::optional<at::Tensor> &out_,                    // batch_size x seqlen_q x num_heads x head_size
+                const float softmax_scale,
+                bool is_causal,
+                int window_size_left,
+                int window_size_right,
+                const float /*softcap*/,
+                bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+                int num_splits)
+{
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+
+    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
+    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
+
+    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+
+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
+    const auto sizes = q.sizes();
+
+    const int batch_size = sizes[0];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
+    const int head_size_og = sizes[3];
+
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int num_blocks = !paged_KV ? 0 : kcache.size(0);
+    const int page_block_size = !paged_KV ? 1 : kcache.size(1);
+    TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128");
+    const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
+    const int num_heads_k = kcache.size(2);
+    const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
+    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
+    if (is_causal) { window_size_right = 0; }
+
+    mask_info mask;
+    if (is_causal) {
+        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+        window_size_right = 0;
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
+        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
+    }
+    else if (window_size_left == -1 && window_size_right == -1) {
+        mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
+    }
+    else {
+        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
+        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
+    }
+
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
+    }
+
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+    if (!paged_KV) {
+        CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+        CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+    } else {
+        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+    }
+
+    at::Tensor q_padded, kcache_padded, vcache_padded;
+    if (head_size_og % 8 != 0) {
+        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    } else {
+        q_padded = q;
+        kcache_padded = kcache;
+        vcache_padded = vcache;
+    }
+
+    at::Tensor out;
+    if (out_.has_value()) {
+        out = out_.value();
+        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        CHECK_DEVICE(out);
+        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+    } else {
+        out = torch::empty_like(q_padded);
+    }
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_8x = round_multiple(head_size_og, 8);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+    auto opts = q.options();
+
+    // TODO - check gradient, only training require lse
+    bool has_lse = true;
+    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+
+    int seqlen_knew = 0;
+    at::Tensor k, v, k_padded, v_padded;
+    if (k_.has_value()) {
+        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
+        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
+        TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
+        k = k_.value();
+        v = v_.value();
+        TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
+        TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
+        CHECK_DEVICE(k); CHECK_DEVICE(v);
+        TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
+        TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
+        seqlen_knew = k.size(1);
+        CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
+        if (head_size_og % 8 != 0) {
+            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        } else {
+            k_padded = k;
+            v_padded = v;
+        }
+    }
+
+    if (seqlens_k_.has_value()) {
+        auto seqlens_k = seqlens_k_.value();
+        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
+        CHECK_DEVICE(seqlens_k);
+        CHECK_CONTIGUOUS(seqlens_k);
+        CHECK_SHAPE(seqlens_k, batch_size);
+    }
+
+    int rotary_dim = 0;
+    if (rotary_cos_.has_value()) {
+        TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
+        auto rotary_cos = rotary_cos_.value();
+        CHECK_DEVICE(rotary_cos);
+        rotary_dim = rotary_cos.size(1) * 2;
+        TORCH_CHECK(rotary_dim <= head_size_og, "rotary_dim must be <= headdim");
+        TORCH_CHECK(rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
+        const int seqlen_ro = rotary_cos.size(0);
+        TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
+        CHECK_SHAPE(rotary_cos, seqlen_ro, rotary_dim / 2);
+        CHECK_CONTIGUOUS(rotary_cos);
+        TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
+
+        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
+        auto rotary_sin = rotary_sin_.value();
+        CHECK_DEVICE(rotary_sin);
+        CHECK_SHAPE(rotary_sin, seqlen_ro, rotary_dim / 2);
+        CHECK_CONTIGUOUS(rotary_sin);
+        TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
+    }
+
+
+    if (cache_batch_idx_.has_value()) {
+        auto cache_batch_idx = cache_batch_idx_.value();
+        CHECK_DEVICE(cache_batch_idx);
+        CHECK_CONTIGUOUS(cache_batch_idx);
+        TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
+    }
+
+    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, 0, num_splits);
+    TORCH_CHECK(num_splits > 0, "num_splits should be greater than 0");
+    TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
+
+    // Keep references to these tensors to extend their lifetime
+    auto softmax_lse_accum = torch::empty({num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+    auto out_accum = torch::empty({num_splits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
+
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    ck_tile::stream_config stream_config{stream};
+
+    if (seqlen_knew > 0 || rotary_dim > 0) {
+        auto appendkv_traits =
+            get_ck_fmha_fwd_appendkv_traits(q_dtype_str, head_size_8x, rotary_dim, is_rotary_interleaved);
+
+        auto appendkv_args =
+            get_ck_fmha_fwd_appendkv_args(
+                batch_size,
+                seqlen_q,
+                seqlen_knew,
+                num_heads,
+                num_heads_k,
+                head_size_8x,
+                rotary_dim,
+                mask.type != mask_enum::no_mask,
+                page_block_size,
+                q_padded,
+                kcache_padded,
+                vcache_padded,
+                k_padded,
+                v_padded,
+                seqlens_k_,
+                rotary_cos_,
+                rotary_sin_,
+                cache_batch_idx_,
+                block_table_);
+
+        fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
+    }
+
+    // seqlens_k_ holds the current seqlens of the KV cache. We need to add seqlen_knew to them before running attention
+    auto append_seqlens_k = torch::empty({batch_size}, opts.dtype(torch::kInt32));
+    if (seqlens_k_.has_value())
+        append_seqlens_k = seqlens_k_.value() + seqlen_knew;
+    else
+        append_seqlens_k.fill_(seqlen_knew);
+
+    // we use splitkv even when num_splits == 1, because fmha_fwd() does not support seqlen_k_ in batch mode
+    auto splitkv_traits =
+        get_ck_fmha_fwd_splitkv_traits(mask, q_dtype_str, head_size_8x, has_lse, alibi_slopes_.has_value());
+
+    auto splitkv_args =
+        get_ck_fmha_fwd_splitkv_args(
+            has_lse,
+            mask,
+            batch_size,
+            seqlen_q,
+            seqlen_k,
+            num_heads,
+            num_heads_k,
+            head_size_8x,
+            page_block_size,
+            num_splits,
+            softmax_scale,
+            q_padded,
+            kcache_padded,
+            vcache_padded,
+            append_seqlens_k,
+            cache_batch_idx_,
+            block_table_,
+            alibi_slopes_,
+            out,
+            softmax_lse,
+            softmax_lse_accum,
+            out_accum);
+
+    fmha_fwd_splitkv(splitkv_traits, splitkv_args, stream_config);
+
+    if (head_size_og % 8 != 0) {
+        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        if (out_.has_value()) { out_.value().copy_(out); }
+        if (k_.has_value()) {
+            // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
+            // but we don't expect to get this case in practice. This is just so that the code works for that case.
+            kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+            vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+        }
+    }
+
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
+    }
+    return {out, softmax_lse};
+}
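As with the non-paged sketch after flash_api.cpp, here is a hedged example of the paged-KV path this file adds, including appending one new token per sequence. The module name and shapes are assumptions; the constraints (int32 block_table and seqlens_k, page_block_size a multiple of 128, no cache_batch_idx with paging) come from the checks above.

```python
# Hedged sketch of the paged KV-cache path (module name is an assumption).
import math
import torch
import flash_attn_2_cuda as flash_attn_cuda

b, h, h_k, d = 2, 32, 8, 128
page_block_size = 128                     # must be a multiple of 128 (checked above)
max_num_blocks_per_seq = 16               # seqlen_k = 16 * 128 = 2048
num_blocks = b * max_num_blocks_per_seq

q = torch.randn(b, 1, h, d, device="cuda", dtype=torch.float16)
kcache = torch.randn(num_blocks, page_block_size, h_k, d, device="cuda", dtype=torch.float16)
vcache = torch.randn_like(kcache)
knew = torch.randn(b, 1, h_k, d, device="cuda", dtype=torch.float16)   # append one token
vnew = torch.randn_like(knew)
seqlens_k = torch.full((b,), 1000, dtype=torch.int32, device="cuda")   # tokens already cached
block_table = torch.arange(num_blocks, dtype=torch.int32,
                           device="cuda").reshape(b, max_num_blocks_per_seq)

out, softmax_lse = flash_attn_cuda.fwd_kvcache(
    q, kcache, vcache, knew, vnew, seqlens_k,
    None, None,        # rotary_cos_, rotary_sin_
    None,              # cache_batch_idx_ (not supported together with paging)
    None,              # leftpad_k_
    block_table, None, # block_table_, alibi_slopes_
    None,              # out_
    1.0 / math.sqrt(d),
    True,              # is_causal
    -1, -1, 0.0,       # window_size_left, window_size_right, softcap
    False,             # is_rotary_interleaved
    0,                 # num_splits: <= 0 lets override_num_splits_if_necessary decide
)
```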

+ 75 - 36
csrc/flash_attn_ck/mha_varlen_bwd.cpp

@@ -11,7 +11,8 @@ fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
                                               std::string dtype,
                                               int head_size,
                                               bool has_dropout,
-                                              bool enable_alibi)
+                                              bool enable_alibi,
+                                              bool deterministic)
 {
     return fmha_bwd_traits{head_size,
                            head_size,
@@ -20,7 +21,9 @@ fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
                            mask.type,
                            enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                            false,    // has_dbias
-                           has_dropout};
+                           has_dropout,
+                           false, // s_randval
+                           deterministic};
 }
 
 fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
@@ -41,6 +44,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                                           const at::Tensor out,
                                           const at::Tensor softmax_lse,
                                           const at::Tensor dout,
+                                          at::Tensor dq_acc,
                                           at::Tensor d,
                                           at::Tensor dq,
                                           at::Tensor dk,
@@ -50,45 +54,62 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                                           uint64_t drop_seed,
                                           uint64_t drop_offset)
 {
-    // q: (total_q, nheads, hdim)
-    // k: (total_k, nheads_k, hdim)
-    // v: (total_k, nheads_k, hdim)
-    // o: (total_q, nheads, hdim)
-    // dq: (total_q, nheads, hdim)
-    // dk_expanded: (total_k, nheads, hdim)
-    // dv_expanded: (total_k, nheads, hdim)
-    // do: (total_q, nheads, hdim)
-
-    // alibi_slopes:(batch_size, nheads) or (nhead)
-    // lse: (batch_size, nheads, max_seqlen_q)
-    // d: (batch_size, nheads, max_seqlen_q)
-
     ck_tile::index_t total_q = q.size(0);
     ck_tile::index_t total_k = k.size(0);
 
+    // q: (total_q, nheads, hdim)
+    ck_tile::index_t batch_stride_q = 0;
     ck_tile::index_t stride_q = q.stride(0);
-    ck_tile::index_t stride_k = k.stride(0);
-    ck_tile::index_t stride_v = v.stride(0);
-    ck_tile::index_t stride_o = out.stride(0);
-    ck_tile::index_t stride_do = dout.stride(0);
-    ck_tile::index_t stride_dk = dk.stride(0);
-    ck_tile::index_t stride_dv = dv.stride(0);
-
     ck_tile::index_t nhead_stride_q = q.stride(1);
-    ck_tile::index_t nhead_stride_k = k.stride(1);
-    ck_tile::index_t nhead_stride_v = v.stride(1);
-    ck_tile::index_t nhead_stride_o = out.stride(1);
-    ck_tile::index_t nhead_stride_do = dout.stride(1);
-    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
 
-    ck_tile::index_t batch_stride_q = 0;
+    // k: (total_k, nheads_k, hdim)
     ck_tile::index_t batch_stride_k = 0;
+    ck_tile::index_t stride_k = k.stride(0);
+    ck_tile::index_t nhead_stride_k = k.stride(1);
+
+    // v: (total_k, nheads_k, hdim)
     ck_tile::index_t batch_stride_v = 0;
+    ck_tile::index_t stride_v = v.stride(0);
+    ck_tile::index_t nhead_stride_v = v.stride(1);
+
+    // o: (total_q, nheads, hdim)
     ck_tile::index_t batch_stride_o = 0;
+    ck_tile::index_t stride_o = out.stride(0);
+    ck_tile::index_t nhead_stride_o = out.stride(1);
+
+    // lse: (nheads, total_q)
+    ck_tile::index_t batch_stride_lse = 0;
+    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(0);
+
+    // do: (total_q, nheads, hdim)
     ck_tile::index_t batch_stride_do = 0;
-    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);;
+    ck_tile::index_t stride_do = dout.stride(0);
+    ck_tile::index_t nhead_stride_do = dout.stride(1);
+
+    // d: (batch_size, nheads, max_seqlen_q)
+    // CK assume d share the same stride with lse
+
+    // dq: (total_q, nheads, hdim)
+    ck_tile::index_t batch_stride_dq = 0;
+    ck_tile::index_t stride_dq = dq.stride(0);
+    ck_tile::index_t nhead_stride_dq = dq.stride(1);
+
+
+    // dk_expanded: (total_k, nheads, hdim)
     ck_tile::index_t batch_stride_dk = 0;
+    ck_tile::index_t stride_dk = dk.stride(0);
+    ck_tile::index_t nhead_stride_dk = dk.stride(1);
+
+    // dv_expanded: (total_k, nheads, hdim)
     ck_tile::index_t batch_stride_dv = 0;
+    ck_tile::index_t stride_dv = dv.stride(0);
+    ck_tile::index_t nhead_stride_dv = dv.stride(1);
+
+    // dq_acc: (split, total_q, nheads, hdim)
+    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
+    ck_tile::index_t batch_stride_dq_acc = 0;
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(1);
+    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);
 
     float p_undrop = 1.0 - p_dropout;
 
@@ -101,6 +122,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
         TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
         TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
         alibi_slopes_ptr = alibi_slopes.data_ptr();
+        // alibi_slopes: (batch_size, nheads) or (nheads)
         stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
 
@@ -117,6 +139,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                          dk.data_ptr(),
                          dv.data_ptr(),
                          nullptr, // dbias
+                         dq_acc.data_ptr(), // dq_acc
                          seqlens_q.data_ptr(), // seqstart_q
                          seqlens_k.data_ptr(), // seqstart_k
                          nullptr, // seqlen_k_ptr
@@ -137,6 +160,8 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                          stride_o,
                          0, // stride_randval
                          stride_do,
+                         stride_dq_acc,
+                         stride_dq,
                          stride_dk,
                          stride_dv,
                          0, // stride_dbias, FA without bias
@@ -148,6 +173,10 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                          0, // nhead_stride_randval
                          nhead_stride_do,
                          nhead_stride_lse,
+                         nhead_stride_dq_acc,
+                         nhead_stride_dq,
+                         nhead_stride_dk,
+                         nhead_stride_dv,
                          0, // nhead_stride_dbias, FA without dbias
                          batch_stride_q,
                          batch_stride_k,
@@ -157,15 +186,17 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                          0, // batch_stride_randval
                          batch_stride_do,
                          batch_stride_lse,
+                         batch_stride_dq_acc,
+                         batch_stride_dq,
                          batch_stride_dk,
                          batch_stride_dv,
                          0  , // batch_stride_dbias, FA without dbias
+                         split_stride_dq_acc,
                          mask.left,
                          mask.right,
                          static_cast<ck_tile::index_t>(mask.type),
                          p_dropout,
                          p_undrop,
-                         false, // s_randval
                          {drop_seed, drop_offset}};
 }
 
@@ -239,7 +270,7 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
     const int num_heads_k = k.size(1);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8");
-    TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
+    TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -308,13 +339,20 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
         dout_padded = dout;
     }
 
-
     // Cast to char to avoid compiler warning about narrowing
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
     auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
-    // TODO - CK does not support dq_accum
+    at::Tensor dq_accum;
+
+    if (!deterministic) {
+        dq_accum = torch::zeros({1, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+    } else {
+        const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
+        const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0);
+        dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+    }
 
     at::Tensor dk_expanded, dv_expanded;
     if (num_heads_k != num_heads) {  // MQA / GQA
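
In the deterministic path the CK kernel writes dQ partials into separate accumulation slices instead of updating a single buffer (non-deterministically, e.g. via atomics), so dq_accum gets a leading split dimension of ceil(max_seqlen_k / kN0). A rough Python restatement of the allocation logic above (shapes only, not the actual kernel behaviour):

    import math
    import torch

    def alloc_dq_accum(total_q, num_heads, head_size_8x, max_seqlen_k, deterministic):
        if not deterministic:
            # single accumulation buffer, order of updates not fixed
            return torch.zeros(1, total_q, num_heads, head_size_8x, dtype=torch.float32)
        # one accumulation slice per kN0-sized block of the key sequence
        kN0 = 128 if head_size_8x <= 128 else 64
        nsplits = math.ceil(max_seqlen_k / kN0)
        return torch.zeros(nsplits, total_q, num_heads, head_size_8x, dtype=torch.float32)
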
@@ -351,10 +389,9 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
 
     if (max_seqlen_q > 0) {
         ck_tile::stream_config stream_config{stream};
-        dq.zero_(); // ck use atomic operation on dq
 
         auto traits =
-            get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
+            get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value(), deterministic);
 
         auto args =
             get_ck_fmha_varlen_bwd_args(
@@ -374,6 +411,7 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
                 out,
                 softmax_lse,
                 dout_padded,
+                dq_accum,
                 softmax_d,
                 dq,
                 dk_expanded,
@@ -383,7 +421,8 @@ mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads
                 drop_seed,
                 drop_offset);
 
-        fmha_bwd(traits, args, stream_config);
+        float t = fmha_bwd(traits, args, stream_config);
+        TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
         dk_expanded.zero_();

+ 13 - 17
csrc/flash_attn_ck/mha_varlen_fwd.cpp

@@ -56,7 +56,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
     // o: (total_q, nheads, d)
 
     // alibi_slopes:(batch, nheads) or (nhead)
-    // lse: (batch, nheads, max_seqlen_q)
+    // lse: (nheads, total_q)
     // randval: (nheads, total_q, max_seqlen_k)
 
     ck_tile::index_t total_q = q.size(0);
@@ -72,15 +72,14 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
     ck_tile::index_t nhead_stride_k = k.stride(1);
     ck_tile::index_t nhead_stride_v = v.stride(1);
     ck_tile::index_t nhead_stride_o = out.stride(1);
-    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
+    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
     ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
 
     ck_tile::index_t batch_stride_q = 0;
     ck_tile::index_t batch_stride_k = 0;
     ck_tile::index_t batch_stride_v = 0;
     ck_tile::index_t batch_stride_o = 0;
-
-    ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
+    ck_tile::index_t batch_stride_lse = 0;
     ck_tile::index_t batch_stride_randval = 0;
 
     void *alibi_slopes_ptr = nullptr;
@@ -100,8 +99,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                          v.data_ptr(),
                          alibi_slopes_ptr, // bias
                          has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
-                         nullptr, // lse_acc
-                         nullptr, // o_acc
                          has_lse ? softmax_lse.data_ptr() : nullptr,
                          out.data_ptr(),
                          seqlens_q.data_ptr(), // seqstart_q
@@ -115,7 +112,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                          d,             // hdim_v
                          h,             // nhead
                          h_k,           // nhead_k
-                         1,             // num_splits
                          softmax_scale, // scale_s
                          1,             // scale_p
                          1,             // scale_o
@@ -124,7 +120,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                          stride_v,
                          stride_alibi_slopes,
                          stride_randval,
-                         0, // stride_o_acc,
                          stride_o,
                          nhead_stride_q,
                          nhead_stride_k,
@@ -132,8 +127,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                          0, // nhead_stride_bias, FA without bias
                          nhead_stride_randval,
                          nhead_stride_lse,
-                         0, // nhead_stride_lse_acc
-                         0, // nhead_stride_o_acc
                          nhead_stride_o,
                          batch_stride_q,
                          batch_stride_k,
@@ -141,11 +134,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                          0, // batch_stride_bias, FA without bias
                          batch_stride_randval,
                          batch_stride_lse,
-                         0, // batch_stride_lse_acc
-                         0, // batch_stride_o_acc
                          batch_stride_o,
-                         0, // split_stride_lse_acc
-                         0, // split_stride_o_acc
                          mask.left,
                          mask.right,
                          static_cast<ck_tile::index_t>(mask.type),
@@ -290,7 +279,7 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
 
     at::Tensor softmax_lse;
     // TODO - check gradient, only training require lse
-    softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(torch::kFloat32));
+    softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(torch::kFloat32));
 
     at::Tensor p;
     if (return_dropout_randval) {
@@ -327,7 +316,13 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
-            get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
+            get_ck_fmha_varlen_fwd_traits(
+                mask,
+                q_dtype_str,
+                head_size_8x,
+                has_dropout,
+                has_lse,
+                alibi_slopes_.has_value());
 
         auto args =
             get_ck_fmha_varlen_fwd_args(
@@ -353,7 +348,8 @@ mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_si
                 drop_seed,
                 drop_offset);
 
-        fmha_fwd(traits, args, stream_config);
+        float t = fmha_fwd(traits, args, stream_config);
+        TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
     }
     else {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.

+ 32 - 14
setup.py

@@ -87,6 +87,10 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version
 
 
+def get_hip_version():
+    return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
+
+
 def check_if_cuda_home_none(global_option: str) -> None:
     if CUDA_HOME is not None:
         return
@@ -307,6 +311,8 @@ elif not SKIP_CUDA_BUILD and IS_ROCM:
         os.makedirs("build")
 
     os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
+    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2")
+    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2")
     os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")
 
     print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
@@ -321,8 +327,6 @@ elif not SKIP_CUDA_BUILD and IS_ROCM:
         generator_flag = ["-DOLD_GENERATOR_PATH"]
 
     check_if_rocm_home_none("flash_attn")
-    cc_flag = []
-
     archs = os.getenv("GPU_ARCHS", "native").split(";")
     validate_and_update_archs(archs)
 
@@ -335,7 +339,9 @@ elif not SKIP_CUDA_BUILD and IS_ROCM:
         torch._C._GLIBCXX_USE_CXX11_ABI = True
 
     sources = ["csrc/flash_attn_ck/flash_api.cpp",
+               "csrc/flash_attn_ck/flash_common.cpp",
                "csrc/flash_attn_ck/mha_bwd.cpp",
+               "csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
                "csrc/flash_attn_ck/mha_fwd.cpp",
                "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
                "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
@@ -345,16 +351,14 @@ elif not SKIP_CUDA_BUILD and IS_ROCM:
     rename_cpp_to_cu(sources)
 
     renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
+                       "csrc/flash_attn_ck/flash_common.cu",
                        "csrc/flash_attn_ck/mha_bwd.cu",
+                       "csrc/flash_attn_ck/mha_fwd_kvcache.cu",
                        "csrc/flash_attn_ck/mha_fwd.cu",
                        "csrc/flash_attn_ck/mha_varlen_bwd.cu",
                        "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
-    extra_compile_args = {
-        "cxx": ["-O3", "-std=c++17"] + generator_flag,
-        "nvcc":
-            [
-                "-O3","-std=c++17",
-                "-mllvm", "-enable-post-misched=0",
+
+    cc_flag += ["-O3","-std=c++17",
                 "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                 "-fgpu-flush-denormals-to-zero",
                 "-DCK_ENABLE_BF16",
@@ -366,12 +370,26 @@ elif not SKIP_CUDA_BUILD and IS_ROCM:
                 "-DCK_ENABLE_INT8",
                 "-DCK_USE_XDL",
                 "-DUSE_PROF_API=1",
-                "-D__HIP_PLATFORM_HCC__=1",
                 # "-DFLASHATTENTION_DISABLE_BACKWARD",
-            ]
-            + generator_flag
-            + cc_flag
-        ,
+                "-D__HIP_PLATFORM_HCC__=1"]
+
+    cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"]
+
+    # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
+    hip_version = get_hip_version()
+    if hip_version > Version('5.7.23302'):
+        cc_flag += ["-fno-offload-uniform-block"]
+    if hip_version > Version('6.1.40090'):
+        cc_flag += ["-mllvm", "-enable-post-misched=0"]
+    if hip_version > Version('6.2.41132'):
+        cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true",
+                    "-mllvm", "-amdgpu-function-calls=false"]
+    if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):
+        cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]
+
+    extra_compile_args = {
+        "cxx": ["-O3", "-std=c++17"] + generator_flag,
+        "nvcc": cc_flag + generator_flag,
     }
 
     include_dirs = [
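
The version gates above follow composable_kernel's CMake logic: each workaround flag is keyed off the HIP version that torch reports, which get_hip_version() normalizes into a PEP 440 version (the build hash after the dash becomes a "+local" suffix so parse() accepts it). A quick check of what that normalization produces (the input string is illustrative, not taken from this diff):

    from packaging.version import parse, Version

    raw = "6.2.41133-74af006c"   # shape of a typical torch.version.hip string (example only)
    normalized = raw.split()[-1].rstrip('-').replace('-', '+')
    hip_version = parse(normalized)              # Version('6.2.41133+74af006c')

    assert hip_version > Version('6.2.41132')    # gates the -amdgpu-early-inline-all flags
    assert hip_version < Version('6.3.00000')
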
@@ -410,7 +428,7 @@ def get_wheel_url():
     cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
 
     if IS_ROCM:
-        torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
+        torch_hip_version = get_hip_version()
         hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
         wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
     else:

+ 875 - 9
tests/test_flash_attn_ck.py

@@ -11,6 +11,7 @@ from flash_attn import (
     flash_attn_varlen_func,
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
+    flash_attn_with_kvcache,
 )
 
 from test_flash_attn import (
@@ -18,20 +19,23 @@ from test_flash_attn import (
     convert_flash_attn_S_to_softmax,
     generate_qkv,
     generate_random_padding_mask,
+    _generate_block_kvcache,
     attention_ref,
     attention_kvpacked_ref,
     attention_qkvpacked_ref,
 )
 
+from flash_attn.layers.rotary import apply_rotary_emb
+
 def is_bwd_hdim_supported(d):
-    return d <= 128 and d % 2 == 0
+    return d <= 256
 
 
 def ck_randval_to_dropout_mask(randval, p):
     # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout
     # randval in 255 * [0, 0.7] will be kept
     # If return dropout_mask >=0, value will be kept
-    return torch.floor(255.0 * (1 - p) - randval)
+    return math.floor(255.0 * (1 - p)) - randval.to(torch.float32)
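
The comments above encode CK's keep/drop threshold: for p = 0.3 the cutoff is floor(255 * (1 - 0.3)) = 178, so randval values 0..178 map to a non-negative result (kept) and 179..255 to a negative one (dropped). A quick numeric check of the returned mask (illustrative values):

    import math
    import torch

    p = 0.3
    randval = torch.tensor([0, 100, 178, 179, 255], dtype=torch.uint8)
    mask = math.floor(255.0 * (1 - p)) - randval.to(torch.float32)
    print(mask)        # tensor([178.,  78.,   0.,  -1., -77.])
    print(mask >= 0)   # tensor([ True,  True,  True, False, False]) -> first three kept
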
 
 
 def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded):
@@ -59,7 +63,7 @@ def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_round
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("deterministic", [False, True])
 @pytest.mark.parametrize("alibi", [False, True])
 @pytest.mark.parametrize("local", [False, True])
 @pytest.mark.parametrize("causal", [False, True])
@@ -152,12 +156,12 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
         print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
         print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
 
-        # TODO - use 10 times to check, wait for ck to change dq type to f32
+        # TODO - use 10 times to check, wait for ck to fix bwd precision issue
         assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("deterministic", [False, True])
 @pytest.mark.parametrize("alibi", [False, True])
 @pytest.mark.parametrize("local", [False, True])
 @pytest.mark.parametrize("causal", [False, True])
@@ -270,14 +274,14 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
         print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
         print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
 
-        # TODO - use 10 times to check, wait for ck to change dq type to f32
+        # TODO - use 10 times to check, wait for ck to fix bwd precision issue
         assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
 @pytest.mark.parametrize("kvpacked", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("deterministic", [False, True])
 @pytest.mark.parametrize("alibi", [False, True])
 @pytest.mark.parametrize("local", [False, True])
 @pytest.mark.parametrize("causal", [False, True])
@@ -484,7 +488,7 @@ def test_flash_attn_output(
         print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
         print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
 
-        # TODO - use 10 times to check, wait for ck to change dq type to f32
+        # TODO - use 10 times to check, wait for ck to fix bwd precision issue
         assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()
@@ -748,7 +752,869 @@ def test_flash_attn_varlen_output(
         print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
         print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
 
-        # TODO - use 10 times to check, wait for ck to change dq type to f32
+        # TODO - use 10 times to check, wait for ck to fix bwd precision issue
         assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        # (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
+    if max(seqlen_q, seqlen_k) >= 2048:
+        pytest.skip()
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    causal = True
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
+    out_ref, attn_ref = attention_ref(
+        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
+    )
+    out_pt, attn_pt = attention_ref(
+        q,
+        k,
+        v,
+        None,
+        None,
+        None,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
+        upcast=False,
+        reorder_ops=True,
+    )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most 4 times the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() + 1e-5
+
+    g = torch.randn_like(out)
+    if is_bwd_hdim_supported(d):
+        do_o = (g.float() * out.float()).sum(-1)
+        (
+            dq,
+            dk,
+            dv,
+        ) = torch.autograd.grad(out, (q, k, v), g)
+        (
+            dq_ref,
+            dk_ref,
+            dv_ref,
+        ) = torch.autograd.grad(out_ref, (q, k, v), g)
+        (
+            dq_pt,
+            dk_pt,
+            dv_pt,
+        ) = torch.autograd.grad(out_pt, (q, k, v), g)
+        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+        # TODO - use 10 times to check, wait for ck to fix bwd precision issue
+        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + 1e-4
+        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + 1e-4
+        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item() + 1e-4
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        # (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+# TODO: Support paged_kv_block
+# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None])
+def test_flash_attn_varlen_causal(
+    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
+):
+    if max(seqlen_q, seqlen_k) >= 2048:
+        pytest.skip()
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    causal = True
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
+    if paged_kv_block_size is None:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        block_table = None
+    else:
+        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
+        )
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+        dq_pad_fn,
+        dk_pad_fn,
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    out_unpad = flash_attn_varlen_func(
+        q_unpad,
+        k_unpad if paged_kv_block_size is None else k_cache_paged,
+        v_unpad if paged_kv_block_size is None else v_cache_paged,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        0.0,
+        causal=causal,
+        window_size=window_size,
+        block_table=block_table,
+    )
+    out = output_pad_fn(out_unpad)
+    out_ref, attn_ref = attention_ref(
+        q,
+        k,
+        v,
+        query_padding_mask,
+        key_padding_mask,
+        None,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
+    )
+    out_pt, attn_pt = attention_ref(
+        q,
+        k,
+        v,
+        query_padding_mask,
+        key_padding_mask,
+        None,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
+        upcast=False,
+        reorder_ops=True,
+    )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
+
+    g = torch.randn_like(out)
+    if is_bwd_hdim_supported(d):
+        do_o = (g.float() * out.float()).sum(-1)
+        test_backward = block_table is None
+        if test_backward:
+            (
+                dq_unpad,
+                dk_unpad,
+                dv_unpad,
+            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
+            dq = dq_pad_fn(dq_unpad)
+            dk = dk_pad_fn(dk_unpad)
+            dv = dk_pad_fn(dv_unpad)
+            (
+                dq_ref,
+                dk_ref,
+                dv_ref,
+            ) = torch.autograd.grad(out_ref, (q, k, v), g)
+            (
+                dq_pt,
+                dk_pt,
+                dv_pt,
+            ) = torch.autograd.grad(out_pt, (q, k, v), g)
+            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+        if test_backward:
+            # TODO - use 10 times to check, wait for ck to fix bwd precision issue
+            assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + 1e-5
+            assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + 1e-5
+            assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item() + 1e-5
+
+
+# TODO - support splitkv
+# def test_flash_attn_splitkv
+
+
+# TODO - Support has_leftpad
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("num_splits", [1, 0])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+@pytest.mark.parametrize("new_kv", [False, True])
+@pytest.mark.parametrize("alibi", [False, True])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
+@pytest.mark.parametrize("rotary_interleaved", [False, True])
+@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
+@pytest.mark.parametrize("has_leftpad", [False])
+@pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 128),
+        (1, 339),
+        (3, 1024),
+        (64, 800),
+        (64, 256),
+        (3, 799),
+        (64, 2048),
+        (16, 20000),
+        (1, 128 * 1024),
+        (16, 128 * 1024),
+        (128, 128),
+    ],
+)
+def test_flash_attn_kvcache(
+    seqlen_q,
+    seqlen_k,
+    d,
+    has_batch_idx,
+    has_leftpad,
+    paged_kv_block_size,
+    rotary_fraction,
+    rotary_interleaved,
+    seqlen_new_eq_seqlen_q,
+    causal,
+    local,
+    alibi,
+    new_kv,
+    mha_type,
+    num_splits,
+    dtype,
+):
+    if seqlen_q > seqlen_k and new_kv:
+        pytest.skip()
+    if not new_kv and rotary_fraction > 0.0:
+        pytest.skip()
+    if has_batch_idx and paged_kv_block_size is not None:
+        pytest.skip()
+    if has_leftpad and paged_kv_block_size is not None:
+        pytest.skip()
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 1
+    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
+    nheads = 6
+    # rotary_dim must be a multiple of 16, and must be <= d
+    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    assert nheads % nheads_k == 0
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
+    if new_kv:
+        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
+    else:
+        k, v = None, None
+    if paged_kv_block_size is None:
+        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
+        block_table = None
+    else:
+        (
+            k_cache,
+            v_cache,
+            block_table,
+            k_cache_paged,
+            v_cache_paged,
+            num_blocks,
+        ) = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
+        )
+    cache_seqlens = torch.randint(
+        0 if new_kv else 1,
+        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
+        (
+            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
+            if new_kv
+            else (seqlen_k + 1)
+        ),
+        (batch_size,),
+        dtype=torch.int32,
+        device=device,
+    )
+    if has_leftpad:
+        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
+                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
+                                   for i in range(batch_size)])
+    else:
+        cache_leftpad = None
+    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
+    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
+    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    if has_leftpad:
+        key_padding_mask = torch.logical_and(
+            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
+        )
+    if has_batch_idx:
+        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
+            :batch_size
+        ]
+    else:
+        cache_batch_idx = None
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(
+            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
+        )
+    else:
+        alibi_slopes, attn_bias = None, None
+    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
+    if rotary_dim > 0:
+        angle = (
+            torch.rand(
+                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
+                rotary_dim // 2,
+                device=device,
+            )
+            * 2
+            * math.pi
+        )
+        cos = torch.cos(angle).to(dtype=dtype)
+        sin = torch.sin(angle).to(dtype=dtype)
+        if causal or local:
+            q_ro = apply_rotary_emb(
+                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+            )
+        else:
+            q_ro = rearrange(
+                apply_rotary_emb(
+                    rearrange(q, "b s h d -> b 1 (s h) d"),
+                    cos,
+                    sin,
+                    seqlen_offsets=cache_seqlens,
+                    interleaved=rotary_interleaved,
+                ),
+                "b 1 (s h) d -> b s h d",
+                s=seqlen_q,
+            )
+        # q_ro = q
+        k_ro = apply_rotary_emb(
+            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+        )
+    else:
+        cos, sin = None, None
+        q_ro, k_ro = q, k
+    # k_cache[:, 64:] = -1
+    k_cache_ref = (
+        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+    ).clone()
+    v_cache_ref = (
+        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+    ).clone()
+    if new_kv:
+        update_mask = torch.logical_and(
+            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
+        )
+        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
+        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
+    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+    out = flash_attn_with_kvcache(
+        q,
+        k_cache if paged_kv_block_size is None else k_cache_paged,
+        v_cache if paged_kv_block_size is None else v_cache_paged,
+        k,
+        v,
+        rotary_cos=cos,
+        rotary_sin=sin,
+        cache_seqlens=cache_seqlens,
+        cache_batch_idx=cache_batch_idx,
+        cache_leftpad=cache_leftpad,
+        block_table=block_table,
+        causal=causal,
+        window_size=window_size,
+        rotary_interleaved=rotary_interleaved,
+        alibi_slopes=alibi_slopes,
+        num_splits=num_splits,
+    )
+    # out = flash_attn_with_kvcache(
+    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
+    # )
+    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
+    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
+    # m = qk.amax(-1, keepdim=True)
+    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
+    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
+    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
+    # probs = torch.softmax(qk, dim=-1)
+    out_ref, _ = attention_ref(
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        attn_bias,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
+        key_leftpad=cache_leftpad,
+    )
+    out_pt, _ = attention_ref(
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        attn_bias,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
+        upcast=False,
+        reorder_ops=True,
+        key_leftpad=cache_leftpad,
+    )
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    if new_kv:
+        if paged_kv_block_size is None:
+            k_cache_select = (
+                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
+            )
+            v_cache_select = (
+                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
+            )
+        else:
+            k_cache_select = rearrange(
+                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
+                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+                b=batch_size,
+            )[:, :seqlen_k]
+            v_cache_select = rearrange(
+                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
+                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+                b=batch_size,
+            )[:, :seqlen_k]
+        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
+        assert torch.equal(v_cache_select, v_cache_ref)
+    # mult = 3 would be enough for f16, but bf16 needs 4
+    mult = 4 if not alibi else 5
+    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (239, 1),
+        (3, 799),
+        (799, 3),
+        (1024, 128),
+        (97, 97),
+        (128, 128),
+        (200, 200),
+        (256, 256),
+        (257, 257),
+        (384, 384),
+        (512, 512),
+        (768, 768),
+        # (1024, 1024),
+    ],
+)
+@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
+    nheads = 4
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    torch.random.manual_seed(42)
+    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
+    g = torch.randn_like(out0)
+    if dropout_p == 0 and is_bwd_hdim_supported(d):
+        (
+            dq0,
+            dk0,
+            dv0,
+        ) = torch.autograd.grad(out0, (q, k, v), g)
+        # Numerical error if we just do any arithmetic on dq
+        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()
+
+    for i in range(250):
+        torch.random.manual_seed(42)
+        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
+        assert torch.equal(out, out0)
+        assert torch.equal(lse, lse0)
+
+        if dropout_p == 0:
+            (
+                dq,
+                dk,
+                dv,
+            ) = torch.autograd.grad(out, (q, k, v), g)
+            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
+            if not dq_equal:
+                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
+
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert dq_equal
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [16, 32, 64])
+@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
+def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
+    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
+    in the case where seqlen % 128 != 0.
+    """
+
+    # TODO - 1 or 2 might fail, need to check
+    if seqlen == 1 or seqlen == 2:
+        pytest.skip()
+
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 5
+    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
+    k, v = [
+        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
+        for _ in range(2)
+    ]
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+    out = flash_attn_func(q, k, v, causal=causal)
+    g = torch.randn_like(out)
+    out.backward(g)
+    q_pt = q.detach().clone().requires_grad_(True)
+    k_pt = k.detach().clone().requires_grad_(True)
+    v_pt = v.detach().clone().requires_grad_(True)
+    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
+    out_pt.backward(g)
+    q_ref = q.detach().clone().requires_grad_(True)
+    k_ref = k.detach().clone().requires_grad_(True)
+    v_ref = v.detach().clone().requires_grad_(True)
+    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
+    out_ref.backward(g)
+    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
+    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
+    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
+    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
+    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
+    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
+        q_pt.grad - q_ref.grad
+    ).abs().max().item() + 1e-3
+    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
+        k_pt.grad - k_ref.grad
+    ).abs().max().item() + 1e-3
+    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
+        v_pt.grad - v_ref.grad
+    ).abs().max().item() + 1e-3
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [64, 128])
+@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
+def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
+    """We previously had a bug where we were using the wrong strides of dout, which shows up
+    when dout is not contiguous.
+    """
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 5
+    nheads = 2
+    q, k, v = [
+        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
+        for _ in range(3)
+    ]
+    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
+    # So g is not contiguous
+    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
+    out.backward(g)
+    q_pt = q.detach().clone().requires_grad_(True)
+    k_pt = k.detach().clone().requires_grad_(True)
+    v_pt = v.detach().clone().requires_grad_(True)
+    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
+    out_pt = rearrange(out_pt, "b s ... -> s b ...")
+    out_pt.backward(g)
+    q_ref = q.detach().clone().requires_grad_(True)
+    k_ref = k.detach().clone().requires_grad_(True)
+    v_ref = v.detach().clone().requires_grad_(True)
+    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
+    out_ref = rearrange(out_ref, "b s ... -> s b ...")
+    out_ref.backward(g)
+    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
+    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
+    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
+    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
+    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
+    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
+        q_pt.grad - q_ref.grad
+    ).abs().max().item()
+    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
+        k_pt.grad - k_ref.grad
+    ).abs().max().item()
+    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
+        v_pt.grad - v_ref.grad
+    ).abs().max().item()
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [16, 32, 64])
+def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
+    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
+    in the case where seqlen % 128 != 0 or varlen.
+    """
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    nheads = 5
+    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
+    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
+    Mq = 256
+    Mk = 3
+
+    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
+    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+
+    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
+    g = torch.randn_like(out)
+    out.backward(g)
+
+    assert not q.grad.isnan().any()
+    assert not k.grad.isnan().any()
+    assert not v.grad.isnan().any()
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
+
+    g = torch.randn_like(out)
+    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+    for _ in range(50):
+        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+        assert torch.equal(dv, dv0)
+        assert torch.equal(dk, dk0)
+        assert torch.equal(dq, dq0)
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+        dq_pad_fn,
+        dk_pad_fn,
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    out = flash_attn_varlen_func(
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        0.0,
+        causal=causal,
+        window_size=window_size,
+        deterministic=True,
+    )
+
+    g = torch.randn_like(out)
+    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+    for _ in range(50):
+        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        assert torch.equal(dv, dv0)
+        assert torch.equal(dk, dk0)
+        assert torch.equal(dq, dq0)
+
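
For reference, the decode-style call pattern that test_flash_attn_kvcache exercises reduces to: append the step's K/V into the cache in place at position cache_seqlens, then attend q against everything cached so far. A trimmed-down usage sketch (shapes and values chosen for illustration; num_splits=0 lets the backend pick the split count, as in the test):

    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, nheads, nheads_k, d = 2, 8, 8, 128
    seqlen_q, cache_len = 1, 4096                  # single-token decode step
    dev, dt = "cuda", torch.float16

    q = torch.randn(batch, seqlen_q, nheads, d, device=dev, dtype=dt)
    k_cache = torch.randn(batch, cache_len, nheads_k, d, device=dev, dtype=dt)
    v_cache = torch.randn(batch, cache_len, nheads_k, d, device=dev, dtype=dt)
    k_new = torch.randn(batch, seqlen_q, nheads_k, d, device=dev, dtype=dt)
    v_new = torch.randn(batch, seqlen_q, nheads_k, d, device=dev, dtype=dt)
    cache_seqlens = torch.full((batch,), 1000, dtype=torch.int32, device=dev)

    # k_new/v_new are written into k_cache/v_cache at [cache_seqlens, cache_seqlens + seqlen_q)
    # before the attention is computed.
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache, k_new, v_new,
        cache_seqlens=cache_seqlens, causal=True, num_splits=0,
    )                                              # (batch, seqlen_q, nheads, d)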