浏览代码

Implement rotary embedding in flash_attn_with_kvcache

Tri Dao 1 年之前
父节点
当前提交
ccbb14f38e

+ 58 - 42
csrc/flash_attn/flash_api.cpp

@@ -13,7 +13,9 @@
 #include "flash.h"
 #include "static_switch.h"
 
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 
 
 void set_params_fprop(Flash_fwd_params &params,
@@ -260,9 +262,7 @@ mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 
-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -299,7 +299,7 @@ mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
-        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
@@ -426,17 +426,15 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
 
-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
-    TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(cu_seqlens_q);
+    CHECK_DEVICE(cu_seqlens_k);
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-    TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
-    TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
+    CHECK_CONTIGUOUS(cu_seqlens_q);
+    CHECK_CONTIGUOUS(cu_seqlens_k);
 
     const auto sizes = q.sizes();
 
@@ -471,7 +469,7 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
-        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, total_q, num_heads, head_size_og);
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
@@ -610,12 +608,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
 
-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
-    TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
-    TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -657,7 +651,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     if (dq_.has_value()) {
         dq = dq_.value();
         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
-        TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
+        CHECK_DEVICE(dq);
         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
         CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
     } else {
@@ -666,7 +660,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     if (dk_.has_value()) {
         dk = dk_.value();
         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
-        TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
+        CHECK_DEVICE(dk);
         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
         CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
     } else {
@@ -675,7 +669,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     if (dv_.has_value()) {
         dv = dv_.value();
         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
-        TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
+        CHECK_DEVICE(dv);
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
         CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
     } else {
@@ -820,22 +814,17 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
 
-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
-    TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
-    TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
-    TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
-    TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
+    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
     TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
-    TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
-    TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
+    CHECK_CONTIGUOUS(cu_seqlens_q);
+    CHECK_CONTIGUOUS(cu_seqlens_k);
 
     const auto sizes = q.sizes();
 
@@ -873,7 +862,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     if (dq_.has_value()) {
         dq = dq_.value();
         TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
-        TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
+        CHECK_DEVICE(dq);
         TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
         CHECK_SHAPE(dq, total_q, num_heads, head_size);
     } else {
@@ -882,7 +871,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     if (dk_.has_value()) {
         dk = dk_.value();
         TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
-        TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
+        CHECK_DEVICE(dk);
         TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
         CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
     } else {
@@ -891,7 +880,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     if (dv_.has_value()) {
         dv = dv_.value();
         TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
-        TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
+        CHECK_DEVICE(dv);
         TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
         CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
     } else {
@@ -1000,9 +989,12 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                 c10::optional<const at::Tensor> &seqlens_k_, // batch_size
+                c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
+                c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
+                bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits
                 ) {
 
@@ -1023,9 +1015,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
 
-    TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(kcache.is_cuda(), "Input tensor must be on CUDA device");
-    TORCH_CHECK(vcache.is_cuda(), "Input tensor must be on CUDA device");
+    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -1071,7 +1061,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     if (out_.has_value()) {
         out = out_.value();
         TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
-        TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
+        CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
@@ -1118,8 +1108,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
         v = v_.value();
         TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
         TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
-        TORCH_CHECK(k.is_cuda(), "Key tensor must be on CUDA device");
-        TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
+        CHECK_DEVICE(k); CHECK_DEVICE(v);
         TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
         TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
         int seqlen_knew = k.size(1);
@@ -1147,13 +1136,40 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     if (seqlens_k_.has_value()) {
         auto seqlens_k = seqlens_k_.value();
         TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
-        TORCH_CHECK(seqlens_k.is_cuda(), "seqlens_k must be on CUDA device");
-        TORCH_CHECK(seqlens_k.is_contiguous(), "seqlens_k must be contiguous");
+        CHECK_DEVICE(seqlens_k);
+        CHECK_CONTIGUOUS(seqlens_k);
         CHECK_SHAPE(seqlens_k, batch_size);
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
 
+    if (rotary_cos_.has_value()) {
+        TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
+        auto rotary_cos = rotary_cos_.value();
+        CHECK_DEVICE(rotary_cos);
+        params.rotary_dim = rotary_cos.size(1) * 2;
+        TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
+        TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
+        const int seqlen_ro = rotary_cos.size(0);
+        TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
+        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
+        CHECK_CONTIGUOUS(rotary_cos);
+        TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
+
+        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
+        auto rotary_sin = rotary_sin_.value();
+        CHECK_DEVICE(rotary_sin);
+        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
+        CHECK_CONTIGUOUS(rotary_sin);
+        TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
+        params.rotary_cos_ptr = rotary_cos.data_ptr();
+        params.rotary_sin_ptr = rotary_sin.data_ptr();
+        params.is_rotary_interleaved = is_rotary_interleaved;
+    } else {
+        params.rotary_dim = 0;
+    }
+
+
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = is_sm90 || is_sm8x
         ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))

+ 7 - 1
csrc/flash_attn/src/flash.h

@@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@@ -91,6 +91,10 @@ struct Flash_fwd_params : public Qkv_params {
     index_t knew_head_stride;
     index_t vnew_head_stride;
 
+    // The cos and sin matrices for rotary embedding.
+    void * __restrict__ rotary_cos_ptr;
+    void * __restrict__ rotary_sin_ptr;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;
@@ -114,6 +118,8 @@ struct Flash_fwd_params : public Qkv_params {
     // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
     bool is_seqlens_k_cumulative;
 
+    bool is_rotary_interleaved;
+
     int num_splits;  // For split-KV version
 };
 

+ 90 - 8
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -744,10 +744,36 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Prologue
 
+    // Copy from Knew to K, optionally apply rotary embedding.
+    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
     if constexpr (Append_KV) {
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+                                  make_stride(params.rotary_dim / 2, _1{}));
+        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+                                  make_stride(params.rotary_dim / 2, _1{}));
+        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+                                      Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                      make_stride(params.rotary_dim / 2, _1{}));
+        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+                                      Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                      make_stride(params.rotary_dim / 2, _1{}));
+        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
+        // if (cute::thread(8, 0)) { print_tensor(gCos); }
+        // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
+
         const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
         const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
@@ -769,17 +795,39 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
         for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
-            flash::copy_w_min_idx<Is_even_K>(
-                tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
-            );
             flash::copy_w_min_idx<Is_even_K>(
                 tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
             );
-            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
             tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+            if (params.rotary_dim == 0) {
+                flash::copy_w_min_idx<Is_even_K>(
+                    tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+                );
+            } else {
+                if (params.is_rotary_interleaved) {
+                    // Don't clear OOB_K because we're writing to global memory
+                    flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
+                        tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
+                    );
+                    tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
+                    tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
+                } else {
+                    // Don't clear OOB_K because we're writing to global memory
+                    flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
+                        tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
+                    );
+                    tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+                    tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+
+                }
+            }
+            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
         }
+        // Need this before we can read in K again, so that we'll see the updated K values.
         __syncthreads();
         if (n_block_max > n_block_copy_min) {
             tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
@@ -787,10 +835,44 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         }
     }
 
+    // Read Q from gmem to smem, optionally apply rotary embedding.
     Tensor tQrQ = make_fragment_like(tQgQ);
-    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
-                                       binfo.actual_seqlen_q - m_block * kBlockM);
+    if (!Append_KV || params.rotary_dim == 0) {
+        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+                                           binfo.actual_seqlen_q - m_block * kBlockM);
+    } else {
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
+        // We do this by setting the row stride of gCos / gSin to 0.
+        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        if (params.is_rotary_interleaved) {
+            flash::copy_rotary_interleaved<Is_even_K>(
+                tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+                0, params.d, params.rotary_dim
+            );
+        } else {
+            flash::copy_rotary_contiguous<Is_even_K>(
+                tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+                0, params.d, params.rotary_dim
+            );
+        }
+    }
 
     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.

+ 12 - 3
csrc/flash_attn/src/kernel_traits.h

@@ -142,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
         DefaultCopy
     >;
     using GmemTiledCopyQKV = decltype(
-        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
     using GmemTiledCopyO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
     static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
@@ -155,7 +155,7 @@ struct Flash_fwd_kernel_traits : public Base {
                                    Stride<Int<kGmemThreadsPerRowP>, _1>>;
 
     using GmemTiledCopyP = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtomP{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
 
@@ -170,6 +170,15 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
+    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    using GmemTiledCopyRotcossin = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinCont = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
 };
 
 // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.

+ 127 - 31
csrc/flash_attn/src/utils.h

@@ -355,65 +355,161 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <bool Is_2_sources=false, bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
-          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+template <bool Is_even_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_2_sources(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S0,
-                                      Tensor<Engine0, Layout0> const &S1,
+inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                       Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                       Tensor<Engine3, Layout3> const &predicate_K,
-                                      const int max_MN=0, const int row_idx_switch=0) {
-    CUTE_STATIC_ASSERT_V(rank(S0) == Int<3>{} && rank(S1) == Int<3>{});
+                                      const int max_MN=0, const int min_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
-    CUTE_STATIC_ASSERT_V(size<0>(S0) == size<0>(D) && size<0>(S1) == size<0>(D)); // MMA
-    CUTE_STATIC_ASSERT_V(size<1>(S0) == size<1>(D) && size<1>(S1) == size<1>(D)); // MMA_M
-    CUTE_STATIC_ASSERT_V(size<2>(S0) == size<2>(D) && size<2>(S1) == size<2>(D)); // MMA_K
-    // There's no case where !Clear_OOB_K && Clear_OOB_MN
-    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
-    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", Is_2_sources, max_MN, row_idx_switch); }
-    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, Is_2_sources = %d, max_MN = %d, row_idx_switch = %d\n", blockIdx.y, Is_2_sources, max_MN, row_idx_switch); }
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
     #pragma unroll
-    for (int m = 0; m < size<1>(S0); ++m) {
-        auto &S = !Is_2_sources || get<0>(identity_MN(0, m, 0)) < row_idx_switch ? S0 : S1;
-        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+    for (int m = 0; m < size<1>(S); ++m) {
+        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
             #pragma unroll
-            for (int k = 0; k < size<2>(S0); ++k) {
+            for (int k = 0; k < size<2>(S); ++k) {
                 if (Is_even_K || predicate_K(k)) {
-                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+                    cute::copy(S(_, m, k), D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
+                                               Tensor<Engine1, Layout1> &D,
+                                               Tensor<Engine2, Layout2> const &Cos,
+                                               Tensor<Engine2, Layout2> const &Sin,
+                                               Tensor<Engine3, Layout3> const &identity_MN,
+                                               const int max_MN, const int min_MN,
+                                               const int dim, const int rotary_dim) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));                     // MMA_K
+    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
+    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
+    Tensor rCos = make_fragment_like(Cos);
+    Tensor rSin = make_fragment_like(Sin);
+    Tensor rS = make_fragment_like(S);
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+                    cute::copy(S(_, m, k), rS(_, m, k));
+                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+                        cute::copy(Cos(_, m, k), rCos(_, m, k));
+                        cute::copy(Sin(_, m, k), rSin(_, m, k));
+                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+                        #pragma unroll
+                        for (int i = 0; i < size<0>(rS) / 2; ++i) {
+                            float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
+                            float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
+                            S_fp32(2 * i) = real;
+                            S_fp32(2 * i + 1) = imag;
+                        }
+                        // Idk but I need to copy for the convert_type to work
+                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
+                        cute::copy(S_fp32, S_fp32_copy);
+                        using T = typename Engine0::value_type;
+                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
+                        cute::copy(S_og_type, rS(_, m, k));
+                    }
+                    cute::copy(rS(_, m, k), D(_, m, k));
                 } else if (Clear_OOB_K) {
                     cute::clear(D(_, m, k));
                 }
             }
-        } else if (Clear_OOB_MN) {
-            cute::clear(D(_, m, _));
         }
     }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <bool Is_even_K=true,
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
-                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                                      Tensor<Engine3, Layout3> const &predicate_K,
-                                      const int max_MN=0, const int min_MN=0) {
+inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
+                                              Tensor<Engine1, Layout1> &D,
+                                              Tensor<Engine2, Layout2> const &Cos,
+                                              Tensor<Engine2, Layout2> const &Sin,
+                                              Tensor<Engine3, Layout3> const &identity_MN,
+                                              const int max_MN, const int min_MN,
+                                              const int dim, const int rotary_dim) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
     CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
-    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
+    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
+    Tensor rCos = make_fragment_like(Cos);
+    Tensor rSin = make_fragment_like(Sin);
+    Tensor rS = make_fragment_like(S);
+    Tensor rS_other = make_fragment_like(rS(_, 0, 0));
     #pragma unroll
     for (int m = 0; m < size<1>(S); ++m) {
-        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
         if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
-            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
             #pragma unroll
             for (int k = 0; k < size<2>(S); ++k) {
-                if (Is_even_K || predicate_K(k)) {
-                    cute::copy(S(_, m, k), D(_, m, k));
+                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+                    cute::copy(S(_, m, k), rS(_, m, k));
+                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+                        const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
+                        Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
+                        cute::copy(gS_other, rS_other);
+                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
+                        Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
+                        Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
+                        cute::copy(gCos, rCos(_, m, k));
+                        cute::copy(gSin, rSin(_, m, k));
+                        // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
+                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+                        Tensor S_other_fp32 = convert_type<float>(rS_other);
+                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+                        #pragma unroll
+                        for (int i = 0; i < size<0>(rS); ++i) {
+                            S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
+                        }
+                        // Idk but I need to copy for the convert_type to work
+                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
+                        cute::copy(S_fp32, S_fp32_copy);
+                        using T = typename Engine0::value_type;
+                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
+                        cute::copy(S_og_type, rS(_, m, k));
+                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
+                    }
+                    cute::copy(rS(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
                 }
             }
         }
@@ -422,4 +518,4 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-}  // namespace flash
+}  // namespace flash

+ 1 - 1
csrc/ft_attention/ft_attention.cpp

@@ -175,7 +175,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         TORCH_CHECK(rotary_sin_.has_value());
         auto rotary_sin = rotary_sin_.value();
         CHECK_DEVICE(rotary_sin);
-        CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2);
+        CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2);
         CHECK_CONTIGUOUS(rotary_sin);
         TORCH_CHECK(rotary_sin.scalar_type() == input_type);
     }

+ 32 - 2
flash_attn/flash_attn_interface.py

@@ -800,9 +800,12 @@ def flash_attn_with_kvcache(
     v_cache,
     k=None,
     v=None,
+    rotary_cos=None,
+    rotary_sin=None,
     cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     softmax_scale=None,
     causal=False,
+    rotary_interleaved=True,
     num_splits=0,
 ):
     """
@@ -815,7 +818,13 @@ def flash_attn_with_kvcache(
     For example, the KV cache could be pre-allocated with the max sequence length, and you can use
     cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
 
-    Does not support backward pass.
+    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated
+    by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens,
+    cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin
+    at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
+    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
 
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -834,6 +843,8 @@ def flash_attn_with_kvcache(
         1 1
     If the row of the mask is all zero, the output will be zero.
 
+    Note: Does not support backward pass.
+
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
         k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
@@ -841,11 +852,18 @@ def flash_attn_with_kvcache(
         k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
             k with k_cache, starting at the indices specified by cache_seqlens.
         v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
+        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
+            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
+        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
         cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
             KV cache.
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
+            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
+            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
+            (i.e. GPT-NeoX style).
         num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
@@ -865,6 +883,18 @@ def flash_attn_with_kvcache(
             (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
         )
     out, softmax_lse = flash_attn_cuda.fwd_kvcache(
-        q, k_cache, v_cache, k, v, cache_seqlens, None, softmax_scale, causal, num_splits
+        q,
+        k_cache,
+        v_cache,
+        k,
+        v,
+        cache_seqlens,
+        rotary_cos,
+        rotary_sin,
+        None,
+        softmax_scale,
+        causal,
+        rotary_interleaved,
+        num_splits,
     )
     return out

+ 1 - 1
tests/models/test_gpt.py

@@ -280,7 +280,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 
 @pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
 # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
-@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
+@pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
 # @pytest.mark.parametrize('rotary', [None])
 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
 # @pytest.mark.parametrize("fused_ft_kernel", [False])

+ 69 - 10
tests/test_flash_attn.py

@@ -15,6 +15,7 @@ from flash_attn import (
 )
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size
+from flash_attn.layers.rotary import apply_rotary_emb
 
 MAX_HEADDIM_SM8x = 192
 
@@ -1497,12 +1498,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 @pytest.mark.parametrize("new_kv", [False, True])
 # @pytest.mark.parametrize("new_kv", [True])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [False])
+# @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
+# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
+@pytest.mark.parametrize("rotary_interleaved", [False, True])
+# @pytest.mark.parametrize("rotary_interleaved", [False])
+@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
+# @pytest.mark.parametrize("rotary_fraction", [1.0])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
-# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize(
@@ -1523,15 +1528,29 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_kvcache(
-    seqlen_q, seqlen_k, d, seqlen_new_eq_seqlen_q, causal, new_kv, mha_type, num_splits, dtype
+    seqlen_q,
+    seqlen_k,
+    d,
+    rotary_fraction,
+    rotary_interleaved,
+    seqlen_new_eq_seqlen_q,
+    causal,
+    new_kv,
+    mha_type,
+    num_splits,
+    dtype,
 ):
     if seqlen_q > seqlen_k and new_kv:
         pytest.skip()
+    if not new_kv and rotary_fraction > 0.0:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
     batch_size = 2
     nheads = 6
+    # rotary_dim must be a multiple of 16, and must be <= d
+    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
@@ -1545,12 +1564,42 @@ def test_flash_attn_kvcache(
     v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype)
     cache_seqlens = torch.randint(
         0,
-        (seqlen_k - seqlen_new + 1) if new_kv else (seqlen_k + 1),
+        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
+        (seqlen_k - (seqlen_q if causal and rotary_dim > 1 else seqlen_new) + 1)
+        if new_kv
+        else (seqlen_k + 1),
         (batch_size,),
         dtype=torch.int32,
         device=device,
     )
     # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
+    if rotary_dim > 0:
+        angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
+        cos = torch.cos(angle).to(dtype=dtype)
+        sin = torch.sin(angle).to(dtype=dtype)
+        if causal:
+            q_ro = apply_rotary_emb(
+                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+            )
+        else:
+            q_ro = rearrange(
+                apply_rotary_emb(
+                    rearrange(q, "b s h d -> b 1 (s h) d"),
+                    cos,
+                    sin,
+                    seqlen_offsets=cache_seqlens,
+                    interleaved=rotary_interleaved,
+                ),
+                "b 1 (s h) d -> b s h d",
+                s=seqlen_q,
+            )
+        # q_ro = q
+        k_ro = apply_rotary_emb(
+            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+        )
+    else:
+        cos, sin = None, None
+        q_ro, k_ro = q, k
     # k_cache[:, 64:] = -1
     k_cache_ref = k_cache.clone()
     v_cache_ref = v_cache.clone()
@@ -1560,12 +1609,22 @@ def test_flash_attn_kvcache(
         update_mask = torch.logical_and(
             cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
         )
-        k_cache_ref[update_mask] = rearrange(k, "b s ... -> (b s) ...")
+        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
         v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
     out = flash_attn_with_kvcache(
-        q, k_cache, v_cache, k, v, cache_seqlens, causal=causal, num_splits=num_splits
+        q,
+        k_cache,
+        v_cache,
+        k,
+        v,
+        cos,
+        sin,
+        cache_seqlens,
+        causal=causal,
+        rotary_interleaved=rotary_interleaved,
+        num_splits=num_splits,
     )
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal)
     # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal)
@@ -1577,10 +1636,10 @@ def test_flash_attn_kvcache(
     # probs = torch.softmax(qk, dim=-1)
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
     out_ref, _ = attention_ref(
-        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
+        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal
     )
     out_pt, _ = attention_ref(
-        q,
+        q_ro,
         k_cache_rep,
         v_cache_rep,
         None,
@@ -1598,10 +1657,10 @@ def test_flash_attn_kvcache(
 
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
-    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
     if new_kv:
-        assert torch.equal(k_cache, k_cache_ref)
+        assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3)
         assert torch.equal(v_cache, v_cache_ref)
+    assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5
 
 
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))