소스 검색

Support AMD ROCm on FlashAttention 2 (#1010)

* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line, c++ api will handle this.
Sync with test_flash_attn.py

* fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* update CK

* Add mha_varlen_bwd

* Use same python as build flash-attn to generate ck kernel

* Fix bug of group mode fwd about returning softmax lse

* larger the test tollerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since we already implement mha_bwd

* Refine get value from tuple

* Use default parameter for stream_config

* unblock all platform

* Add comment

* refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect rocm environment by pytorch's IS_HIP_EXTENSION

* update to lastest ck

* Add necessary compile flag

* Sync the api with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
rocking 8 달 전
부모
커밋
d8f104e97a

+ 3 - 0
.gitmodules

@@ -1,3 +1,6 @@
 [submodule "csrc/cutlass"]
 	path = csrc/cutlass
 	url = https://github.com/NVIDIA/cutlass.git
+[submodule "csrc/composable_kernel"]
+	path = csrc/composable_kernel
+	url = https://github.com/ROCm/composable_kernel.git

+ 27 - 0
README.md

@@ -434,6 +434,33 @@ This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs.
 
 If you encounter bugs, please open a GitHub Issue!
+## AMD GPU/ROCm Support
+ROCm version use [composable_kernel](https://github.com/ROCm/composable_kernel) as backend. It provides the implementation of FlashAttention-2.
+
+## Installation and features
+Requirements:
+- ROCm 6.0+
+- PyTorch 1.12.1+
+
+We recommend the
+[Pytorch](https://hub.docker.com/r/rocm/pytorch)
+container from ROCm, which has all the required tools to install FlashAttention.
+
+To compile from source:
+```sh
+python setup.py install
+```
+
+FlashAttention-2 on ROCm currently supports:
+1. MI200 or MI300 GPUs.
+2. Datatype fp16 and bf16
+3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
+
+## Tests
+To run the tests:
+```sh
+pytest tests/test_flash_attn_ck.py
+```
 
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:

+ 1 - 0
csrc/composable_kernel

@@ -0,0 +1 @@
+Subproject commit 8182976c37433808b5e3a27a6536d1b74b0c23a1

+ 99 - 0
csrc/flash_attn_ck/flash_api.cpp

@@ -0,0 +1,99 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "flash_common.hpp"
+
+std::vector<at::Tensor>
+mha_fwd(at::Tensor &q,
+        const at::Tensor &k,
+        const at::Tensor &v,
+        c10::optional<at::Tensor> &out_,
+        c10::optional<at::Tensor> &alibi_slopes_,
+        const float p_dropout,
+        const float softmax_scale,
+        bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const float softcap,
+        const bool return_softmax,
+        c10::optional<at::Generator> gen_);
+
+std::vector<at::Tensor>
+mha_varlen_fwd(at::Tensor &q,                            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               c10::optional<at::Tensor> &out_,          // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q,           // b+1
+               const at::Tensor &cu_seqlens_k,           // b+1
+               c10::optional<at::Tensor> &seqused_k,     // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<const at::Tensor> &leftpad_k_, // batch_size
+               c10::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+               int max_seqlen_q,
+               const int max_seqlen_k,
+               const float p_dropout,
+               const float softmax_scale,
+               const bool zero_tensors,
+               bool is_causal,
+               int window_size_left,
+               int window_size_right,
+               const float softcap,
+               const bool return_softmax,
+               c10::optional<at::Generator> gen_);
+
+std::vector<at::Tensor>
+mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x head_size_og
+        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &softmax_lse,            // b x h x seqlen_q
+        c10::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+        const float p_dropout,                    // probability to drop
+        const float softmax_scale,
+        const bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const float softcap,
+        const bool deterministic,
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state);
+
+std::vector<at::Tensor>
+mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads x head_size
+               const at::Tensor &q,                      // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &out,                    // total_q x num_heads x head_size
+               const at::Tensor &softmax_lse,            // b x h x s   softmax logsumexp
+               c10::optional<at::Tensor> &dq_,           // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor> &dk_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor> &dv_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q,           // b+1
+               const at::Tensor &cu_seqlens_k,           // b+1
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+               const int max_seqlen_q,
+               const int max_seqlen_k, // max sequence length to choose the kernel
+               const float p_dropout,  // probability to drop
+               const float softmax_scale,
+               const bool zero_tensors,
+               const bool is_causal,
+               int window_size_left,
+               int window_size_right,
+               const float softcap,
+               const bool deterministic,
+               c10::optional<at::Generator> gen_,
+               c10::optional<at::Tensor> &rng_state);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.doc() = "FlashAttention";
+    m.def("fwd", &mha_fwd, "Forward pass");
+    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
+    m.def("bwd", &mha_bwd, "Backward pass");
+    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
+}

+ 38 - 0
csrc/flash_attn_ck/flash_common.hpp

@@ -0,0 +1,38 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
+#include <torch/python.h>
+#include <torch/nn/functional.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#ifdef OLD_GENERATOR_PATH
+#include <ATen/CUDAGeneratorImpl.h>
+#else
+#include <ATen/cuda/CUDAGeneratorImpl.h>
+#endif
+
+
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
+namespace flash {
+// Copy from PyTorch
+// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
+static std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
+  if (arg.captured_) {
+    // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
+    // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
+    // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
+    return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
+  } else {
+    return std::make_tuple(arg.seed_.val, arg.offset_.val);
+  }
+}
+
+} // namespace flash

+ 379 - 0
csrc/flash_attn_ck/mha_bwd.cpp

@@ -0,0 +1,379 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "flash_common.hpp"
+
+#include "fmha_bwd.hpp"
+#include "mask.hpp"
+
+fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
+                                       std::string dtype,
+                                       int head_size,
+                                       bool has_dropout,
+                                       bool enable_alibi)
+{
+    return fmha_bwd_traits{head_size,
+                           head_size,
+                           dtype,
+                           false, // is_group_mode
+                           mask.type,
+                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
+                           false,    // has_dbias
+                           has_dropout};
+}
+
+fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
+                                   // sizes
+                                   const int b,
+                                   const int seqlen_q,
+                                   const int seqlen_k,
+                                   const int h,
+                                   const int h_k,
+                                   const int hdim,
+                                   // device pointers
+                                   const at::Tensor q,
+                                   const at::Tensor k,
+                                   const at::Tensor v,
+                                   c10::optional<at::Tensor> &alibi_slopes_,
+                                   const at::Tensor out,
+                                   const at::Tensor softmax_lse,
+                                   const at::Tensor dout,
+                                   at::Tensor d,
+                                   at::Tensor dq,
+                                   at::Tensor dk,
+                                   at::Tensor dv,
+                                   float softmax_scale,
+                                   float p_dropout,
+                                   uint64_t drop_seed,
+                                   uint64_t drop_offset)
+{
+    // q: (batch_size, seqlen_q, nheads, hdim)
+    // k: (batch_size, seqlen_k, nheads_k, hdim)
+    // v: (batch_size, seqlen_k, nheads_k, hdim)
+    // o: (batch_size, seqlen_q, nheads, hdim)
+    // dq: (batch_size, seqlen_q, nheads, hdim)
+    // dk_expanded: (batch_size, seqlen_k, nheads, hdim)
+    // dv_expanded: (batch_size, seqlen_k, nheads, hdim)
+    // do: (batch_size, seqlen_q, nheads, hdim)
+
+    // alibi_slopes:(batch_size, nheads) or (nhead)
+    // lse: (batch_size, nheads, seqlen_q)
+    // d: (batch_size, nheads, seqlen_q)
+
+    ck_tile::index_t stride_q = q.stride(1);
+    ck_tile::index_t stride_k = k.stride(1);
+    ck_tile::index_t stride_v = v.stride(1);
+    ck_tile::index_t stride_o = out.stride(1);
+    ck_tile::index_t stride_do = dout.stride(1);
+    ck_tile::index_t stride_dk = dk.stride(1);
+    ck_tile::index_t stride_dv = dv.stride(1);
+
+    ck_tile::index_t nhead_stride_q = q.stride(2);
+    ck_tile::index_t nhead_stride_k = k.stride(2);
+    ck_tile::index_t nhead_stride_v = v.stride(2);
+    ck_tile::index_t nhead_stride_o = out.stride(2);
+    ck_tile::index_t nhead_stride_do = dout.stride(2);
+    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
+
+    ck_tile::index_t batch_stride_q = q.stride(0);
+    ck_tile::index_t batch_stride_k = k.stride(0);
+    ck_tile::index_t batch_stride_v = v.stride(0);
+    ck_tile::index_t batch_stride_o = out.stride(0);
+    ck_tile::index_t batch_stride_do = dout.stride(0);
+    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
+    ck_tile::index_t batch_stride_dk = dk.stride(0);
+    ck_tile::index_t batch_stride_dv = dv.stride(0);
+
+    float p_undrop = 1.0 - p_dropout;
+
+    void *alibi_slopes_ptr = nullptr;
+    ck_tile::index_t stride_alibi_slopes = 0;
+
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
+        alibi_slopes_ptr = alibi_slopes.data_ptr();
+        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    }
+
+    return fmha_bwd_args{q.data_ptr(),
+                         k.data_ptr(),
+                         v.data_ptr(),
+                         alibi_slopes_ptr, // bias
+                         out.data_ptr(),
+                         softmax_lse.data_ptr(),
+                         dout.data_ptr(),
+                         d.data_ptr(),
+                         nullptr, // rand_val
+                         dq.data_ptr(),
+                         dk.data_ptr(),
+                         dv.data_ptr(),
+                         nullptr, // dbias
+                         nullptr, // seqstart_q
+                         nullptr, // seqstart_k
+                         nullptr, // seqlen_k_ptr
+                         seqlen_q,
+                         seqlen_k,
+                         b,
+                         seqlen_q, // max_seqlen_q
+                         seqlen_k, // max_seqlen_k
+                         hdim, // hdim_q
+                         hdim, // hdim_v
+                         h, // nhead
+                         h_k, // nhead_k
+                         softmax_scale,
+                         stride_q,
+                         stride_k,
+                         stride_v,
+                         stride_alibi_slopes,
+                         stride_o,
+                         0, // stride_randval
+                         stride_do,
+                         stride_dk,
+                         stride_dv,
+                         0, // stride_dbias, FA without bias
+                         nhead_stride_q,
+                         nhead_stride_k,
+                         nhead_stride_v,
+                         0, // nhead_stride_bias, FA without bias
+                         nhead_stride_o,
+                         0, // nhead_stride_randval
+                         nhead_stride_do,
+                         nhead_stride_lse,
+                         0, // nhead_stride_dbias, FA without dbias
+                         batch_stride_q,
+                         batch_stride_k,
+                         batch_stride_v,
+                         0  , // batch_stride_bias, FA without bias
+                         batch_stride_o,
+                         0, // batch_stride_randval
+                         batch_stride_do,
+                         batch_stride_lse,
+                         batch_stride_dk,
+                         batch_stride_dv,
+                         0  , // batch_stride_dbias, FA without dbias
+                         mask.left,
+                         mask.right,
+                         static_cast<ck_tile::index_t>(mask.type),
+                         p_dropout,
+                         p_undrop,
+                         false, // s_randval
+                         {drop_seed, drop_offset}};
+}
+
+std::vector<at::Tensor>
+mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads, x head_size_og
+        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &softmax_lse,            // b x h x seqlen_q
+        c10::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+        const float p_dropout,                    // probability to drop
+        const float softmax_scale,
+        const bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const float /*softcap*/,
+        const bool deterministic,
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state)
+{
+#ifdef FLASHATTENTION_DISABLE_BACKWARD
+    TORCH_CHECK(false, "This flash attention build does not support backward.");
+#endif
+    if (is_causal) { window_size_right = 0; }
+
+    bool is_dropout = p_dropout > 0.0;
+    auto stream = at::cuda::getCurrentHIPStream().stream();
+
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
+    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
+
+    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
+
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
+    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
+
+    const auto sizes = q.sizes();
+
+    const int batch_size = sizes[0];
+    const int seqlen_q = sizes[1];
+    const int num_heads = sizes[2];
+    const int head_size_og = dout.size(3);  // unpadded hdim
+    const int head_size_8x = sizes[3];
+    const int seqlen_k = k.size(1);
+    const int num_heads_k = k.size(2);
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
+    TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
+    TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
+
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+    mask_info mask;
+    if (is_causal) {
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
+        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual
+    }
+    else if (window_size_left == -1 && window_size_right == -1) {
+        mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
+    }
+    else {
+        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
+        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
+    }
+
+    // q, k, v, out had been padded in mha_fwd
+    // dq_, dk_, dv_ are also padded tensor
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x);
+    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x);
+    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x);
+    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x);
+    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
+
+    at::Tensor dq, dk, dv;
+    if (dq_.has_value()) {
+        dq = dq_.value();
+        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
+        CHECK_DEVICE(dq);
+        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
+        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x);
+    } else {
+        dq = torch::empty_like(q);
+    }
+    if (dk_.has_value()) {
+    dk = dk_.value();
+    TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
+    CHECK_DEVICE(dk);
+    TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
+    CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x);
+    } else {
+        dk = torch::empty_like(k);
+    }
+    if (dv_.has_value()) {
+        dv = dv_.value();
+        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
+        CHECK_DEVICE(dv);
+        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
+        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x);
+    } else {
+        dv = torch::empty_like(v);
+    }
+
+    at::Tensor dout_padded;
+    if (head_size_og % 8 != 0) {
+        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    } else {
+        dout_padded = dout;
+    }
+
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+    auto opts = q.options();
+    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+    // TODO - CK does not support dq_accum
+
+    at::Tensor dk_expanded, dv_expanded;
+    if (num_heads_k != num_heads) {  // MQA / GQA
+        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
+        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
+    } else {
+        dk_expanded = dk;
+        dv_expanded = dv;
+    }
+
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+    uint64_t drop_seed = 1, drop_offset = 0;
+    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+
+    if (rng_state.has_value()) {
+        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+        drop_seed = d[0];
+        drop_offset = d[1];
+    } else if(is_dropout) {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(gen->mutex_);
+        auto philox_args = gen->philox_cuda_state(counter_offset);
+        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+    }
+
+    if (seqlen_q > 0) {
+        ck_tile::stream_config stream_config{stream};
+        dq.zero_(); // ck use atomic operation on dq
+
+        auto traits =
+            get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
+
+        auto args =
+            get_ck_fmha_bwd_args(
+                mask,
+                batch_size,
+                seqlen_q,
+                seqlen_k,
+                num_heads,
+                num_heads_k,
+                head_size_8x,
+                q,
+                k,
+                v,
+                alibi_slopes_,
+                out,
+                softmax_lse,
+                dout_padded,
+                softmax_d,
+                dq,
+                dk_expanded,
+                dv_expanded,
+                softmax_scale,
+                p_dropout,
+                drop_seed,
+                drop_offset);
+
+        fmha_bwd(traits, args, stream_config);
+    } else {
+        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+        dk_expanded.zero_();
+        dv_expanded.zero_();
+        softmax_d.zero_();
+    }
+
+    // For MQA/GQA we need to sum dK and dV across the groups
+    if (num_heads_k != num_heads) {
+        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
+        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
+    }
+    if (head_size_og % 8 != 0) {
+        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+    }
+
+    return { dq, dk, dv, softmax_d };
+}

+ 348 - 0
csrc/flash_attn_ck/mha_fwd.cpp

@@ -0,0 +1,348 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "flash_common.hpp"
+
+#include "fmha_fwd.hpp"
+#include "mask.hpp"
+
+fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
+                                       std::string dtype,
+                                       int head_size,
+                                       bool has_dropout,
+                                       bool has_lse,
+                                       bool enable_alibi)
+{
+    return fmha_fwd_traits{head_size,
+                           head_size,
+                           dtype,
+                           false, // is_group_mode
+                           true,  // is_v_rowmajor
+                           mask.type,
+                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
+                           has_lse,
+                           has_dropout,
+                           false}; // do_fp8_static_quant
+}
+
+fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
+                                   bool has_dropout_randval,
+                                   const mask_info &mask,
+                                   // sizes
+                                   const int b,
+                                   const int seqlen_q,
+                                   const int seqlen_k,
+                                   const int h,
+                                   const int h_k,
+                                   const int d,
+                                   // device pointers
+                                   const at::Tensor q,
+                                   const at::Tensor k,
+                                   const at::Tensor v,
+                                   c10::optional<at::Tensor> &alibi_slopes_,
+                                   at::Tensor out,
+                                   at::Tensor softmax_lse,
+                                   at::Tensor dropout_randval,
+                                   float softmax_scale,
+                                   float p_dropout,
+                                   uint64_t drop_seed,
+                                   uint64_t drop_offset)
+{
+    // q: (batch_size, seqlen_q, nheads, d)
+    // k: (batch_size, seqlen_k, nheads_k, d)
+    // v: (batch_size, seqlen_k, nheads_k, d)
+    // o: (batch_size, seqlen_q, nheads, d)
+
+    // alibi_slopes:(batch_size, nheads) or (nhead)
+    // lse: (batch_size, nheads, seqlen_q)
+    // randval: (batch_size, nheads, seqlen_q, seqlen_k)
+
+    ck_tile::index_t stride_q = q.stride(1);
+    ck_tile::index_t stride_k = k.stride(1);
+    ck_tile::index_t stride_v = v.stride(1);
+    ck_tile::index_t stride_o = out.stride(1);
+    ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0;
+
+    ck_tile::index_t nhead_stride_q = q.stride(2);
+    ck_tile::index_t nhead_stride_k = k.stride(2);
+    ck_tile::index_t nhead_stride_v = v.stride(2);
+    ck_tile::index_t nhead_stride_o = out.stride(2);
+    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
+    ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
+
+    ck_tile::index_t batch_stride_q = q.stride(0);
+    ck_tile::index_t batch_stride_k = k.stride(0);
+    ck_tile::index_t batch_stride_v = v.stride(0);
+    ck_tile::index_t batch_stride_o = out.stride(0);
+
+    ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
+    ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
+
+    void *alibi_slopes_ptr = nullptr;
+    ck_tile::index_t stride_alibi_slopes = 0;
+
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
+        alibi_slopes_ptr = alibi_slopes.data_ptr();
+        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    }
+
+    return fmha_fwd_args{q.data_ptr(),
+                         k.data_ptr(),
+                         v.data_ptr(),
+                         alibi_slopes_ptr, // bias
+                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
+                         nullptr, // lse_acc
+                         nullptr, // o_acc
+                         has_lse ? softmax_lse.data_ptr() : nullptr,
+                         out.data_ptr(),
+                         nullptr, // seqstart_q
+                         nullptr, // seqstart_k
+                         nullptr,
+                         seqlen_q,
+                         seqlen_k,
+                         b,
+                         seqlen_q,      // max_seqlen_q
+                         d,             // hdim_q
+                         d,             // hdim_v
+                         h,             // nhead
+                         h_k,           // nhead_k
+                         1,             // num_splits
+                         softmax_scale, // scale_s
+                         1,             // scale_p
+                         1,             // scale_o
+                         stride_q,
+                         stride_k,
+                         stride_v,
+                         stride_alibi_slopes,
+                         stride_randval,
+                         0, // stride_o_acc,
+                         stride_o,
+                         nhead_stride_q,
+                         nhead_stride_k,
+                         nhead_stride_v,
+                         0, // nhead_stride_bias, FA without bias
+                         nhead_stride_randval,
+                         nhead_stride_lse,
+                         0, // nhead_stride_lse_acc
+                         0, // nhead_stride_o_acc
+                         nhead_stride_o,
+                         batch_stride_q,
+                         batch_stride_k,
+                         batch_stride_v,
+                         0, // batch_stride_bias, FA without bias
+                         batch_stride_randval,
+                         batch_stride_lse,
+                         0, // batch_stride_lse_acc
+                         0, // batch_stride_o_acc
+                         batch_stride_o,
+                         0, // split_stride_lse_acc
+                         0, // split_stride_o_acc
+                         mask.left,
+                         mask.right,
+                         static_cast<ck_tile::index_t>(mask.type),
+                         p_dropout,
+                         has_dropout_randval,
+                         {drop_seed, drop_offset}};
+}
+
+std::vector<at::Tensor>
+mha_fwd(at::Tensor &q,                            // batch_size x seqlen_q x num_heads x head_size
+        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
+        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
+        c10::optional<at::Tensor> &out_,          // batch_size x seqlen_q x num_heads x head_size
+        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+        const float p_dropout,
+        const float softmax_scale,
+        bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const float /*softcap*/,
+        const bool return_dropout_randval,
+        c10::optional<at::Generator> gen_)
+{
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+
+    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
+
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+
+    const auto sizes = q.sizes();
+
+    const int batch_size = sizes[0];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
+    const int head_size_og = sizes[3];
+    const int seqlen_k = k.size(1);
+    const int num_heads_k = k.size(2);
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
+    TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+    // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
+
+    mask_info mask;
+    if (is_causal) {
+        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+        window_size_right = 0;
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
+        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual
+    }
+    else if (window_size_left == -1 && window_size_right == -1) {
+        mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
+    }
+    else {
+        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
+        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
+    }
+
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
+    const int ngroups = num_heads / num_heads_k;
+    if (seqlenq_ngroups_swapped) {
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
+    }
+
+    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
+
+    at::Tensor q_padded, k_padded, v_padded;
+    if (head_size_og % 8 != 0) {
+        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    }
+    else {
+        q_padded = q;
+        k_padded = k;
+        v_padded = v;
+    }
+
+    at::Tensor out;
+    if (out_.has_value()) {
+        out = out_.value();
+        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        CHECK_DEVICE(out);
+        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+        CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+        if (seqlenq_ngroups_swapped) {
+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        }
+        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+    }
+    else {
+        out = torch::empty_like(q_padded);
+    }
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_8x = round_multiple(head_size_og, 8);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+    auto opts = q.options();
+    bool has_lse = true;
+    bool has_dropout = p_dropout > 0.0f;
+
+    at::Tensor softmax_lse;
+    // TODO - check gradient, only training require lse
+    softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32));
+
+    at::Tensor p;
+    if (return_dropout_randval) {
+        TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
+        p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8));
+    }
+
+    uint64_t drop_seed = 1, drop_offset = 0;
+    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+
+    if (p_dropout > 0.0)  {
+        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+            gen_, at::cuda::detail::getDefaultCUDAGenerator());
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(gen->mutex_);
+        auto philox_args = gen->philox_cuda_state(counter_offset);
+        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+    }
+
+    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
+    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
+
+    if (seqlen_k > 0) {
+        auto stream = at::cuda::getCurrentHIPStream().stream();
+        ck_tile::stream_config stream_config{stream};
+
+        auto traits =
+            get_ck_fmha_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
+
+        auto args =
+            get_ck_fmha_fwd_args(
+                has_lse,
+                return_dropout_randval,
+                mask,
+                batch_size,
+                seqlen_q,
+                seqlen_k,
+                num_heads,
+                num_heads_k,
+                head_size_8x,
+                q_padded,
+                k_padded,
+                v_padded,
+                alibi_slopes_,
+                out,
+                softmax_lse,
+                p,
+                softmax_scale,
+                p_dropout,
+                drop_seed,
+                drop_offset);
+
+        fmha_fwd(traits, args, stream_config);
+    }
+    else {
+        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits<float>::infinity());
+    }
+
+    at::Tensor out_padded = out;
+    if (head_size_og % 8 != 0) {
+        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        if (out_.has_value()) { out_.value().copy_(out); }
+    }
+
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
+    }
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+}

+ 406 - 0
csrc/flash_attn_ck/mha_varlen_bwd.cpp

@@ -0,0 +1,406 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "flash_common.hpp"
+
+#include "fmha_bwd.hpp"
+#include "mask.hpp"
+
+fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
+                                              std::string dtype,
+                                              int head_size,
+                                              bool has_dropout,
+                                              bool enable_alibi)
+{
+    return fmha_bwd_traits{head_size,
+                           head_size,
+                           dtype,
+                           true, // is_group_mode
+                           mask.type,
+                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
+                           false,    // has_dbias
+                           has_dropout};
+}
+
+fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
+                                          // sizes
+                                          const int b,
+                                          const int max_seqlen_q,
+                                          const int max_seqlen_k,
+                                          const int h,
+                                          const int h_k,
+                                          const int hdim,
+                                          // device pointers
+                                          const at::Tensor q,
+                                          const at::Tensor k,
+                                          const at::Tensor v,
+                                          const at::Tensor seqlens_q,
+                                          const at::Tensor seqlens_k,
+                                          c10::optional<at::Tensor> &alibi_slopes_,
+                                          const at::Tensor out,
+                                          const at::Tensor softmax_lse,
+                                          const at::Tensor dout,
+                                          at::Tensor d,
+                                          at::Tensor dq,
+                                          at::Tensor dk,
+                                          at::Tensor dv,
+                                          float softmax_scale,
+                                          float p_dropout,
+                                          uint64_t drop_seed,
+                                          uint64_t drop_offset)
+{
+    // q: (total_q, nheads, hdim)
+    // k: (total_k, nheads_k, hdim)
+    // v: (total_k, nheads_k, hdim)
+    // o: (total_q, nheads, hdim)
+    // dq: (total_q, nheads, hdim)
+    // dk_expanded: (total_k, nheads, hdim)
+    // dv_expanded: (total_k, nheads, hdim)
+    // do: (total_q, nheads, hdim)
+
+    // alibi_slopes:(batch_size, nheads) or (nhead)
+    // lse: (batch_size, nheads, max_seqlen_q)
+    // d: (batch_size, nheads, max_seqlen_q)
+
+    ck_tile::index_t total_q = q.size(0);
+    ck_tile::index_t total_k = k.size(0);
+
+    ck_tile::index_t stride_q = q.stride(0);
+    ck_tile::index_t stride_k = k.stride(0);
+    ck_tile::index_t stride_v = v.stride(0);
+    ck_tile::index_t stride_o = out.stride(0);
+    ck_tile::index_t stride_do = dout.stride(0);
+    ck_tile::index_t stride_dk = dk.stride(0);
+    ck_tile::index_t stride_dv = dv.stride(0);
+
+    ck_tile::index_t nhead_stride_q = q.stride(1);
+    ck_tile::index_t nhead_stride_k = k.stride(1);
+    ck_tile::index_t nhead_stride_v = v.stride(1);
+    ck_tile::index_t nhead_stride_o = out.stride(1);
+    ck_tile::index_t nhead_stride_do = dout.stride(1);
+    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
+
+    ck_tile::index_t batch_stride_q = 0;
+    ck_tile::index_t batch_stride_k = 0;
+    ck_tile::index_t batch_stride_v = 0;
+    ck_tile::index_t batch_stride_o = 0;
+    ck_tile::index_t batch_stride_do = 0;
+    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);;
+    ck_tile::index_t batch_stride_dk = 0;
+    ck_tile::index_t batch_stride_dv = 0;
+
+    float p_undrop = 1.0 - p_dropout;
+
+    void *alibi_slopes_ptr = nullptr;
+    ck_tile::index_t stride_alibi_slopes = 0;
+
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
+        alibi_slopes_ptr = alibi_slopes.data_ptr();
+        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    }
+
+    return fmha_bwd_args{q.data_ptr(),
+                         k.data_ptr(),
+                         v.data_ptr(),
+                         alibi_slopes_ptr, // bias
+                         out.data_ptr(),
+                         softmax_lse.data_ptr(),
+                         dout.data_ptr(),
+                         d.data_ptr(),
+                         nullptr, // rand_val
+                         dq.data_ptr(),
+                         dk.data_ptr(),
+                         dv.data_ptr(),
+                         nullptr, // dbias
+                         seqlens_q.data_ptr(), // seqstart_q
+                         seqlens_k.data_ptr(), // seqstart_k
+                         nullptr, // seqlen_k_ptr
+                         total_q,
+                         total_k,
+                         b,
+                         max_seqlen_q, // max_seqlen_q
+                         max_seqlen_k, // max_seqlen_k
+                         hdim, // hdim_q
+                         hdim, // hdim_v
+                         h, // nhead
+                         h_k, // nhead_k
+                         softmax_scale,
+                         stride_q,
+                         stride_k,
+                         stride_v,
+                         stride_alibi_slopes,
+                         stride_o,
+                         0, // stride_randval
+                         stride_do,
+                         stride_dk,
+                         stride_dv,
+                         0, // stride_dbias, FA without bias
+                         nhead_stride_q,
+                         nhead_stride_k,
+                         nhead_stride_v,
+                         0, // nhead_stride_bias, FA without bias
+                         nhead_stride_o,
+                         0, // nhead_stride_randval
+                         nhead_stride_do,
+                         nhead_stride_lse,
+                         0, // nhead_stride_dbias, FA without dbias
+                         batch_stride_q,
+                         batch_stride_k,
+                         batch_stride_v,
+                         0  , // batch_stride_bias, FA without bias
+                         batch_stride_o,
+                         0, // batch_stride_randval
+                         batch_stride_do,
+                         batch_stride_lse,
+                         batch_stride_dk,
+                         batch_stride_dv,
+                         0  , // batch_stride_dbias, FA without dbias
+                         mask.left,
+                         mask.right,
+                         static_cast<ck_tile::index_t>(mask.type),
+                         p_dropout,
+                         p_undrop,
+                         false, // s_randval
+                         {drop_seed, drop_offset}};
+}
+
+std::vector<at::Tensor>
+mha_varlen_bwd(const at::Tensor &dout,                   // total_q x num_heads x head_size
+               const at::Tensor &q,                      // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &v,                      // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &out,                    // total_q x num_heads x head_size
+               const at::Tensor &softmax_lse,            // b x h x s   softmax logsumexp
+               c10::optional<at::Tensor> &dq_,           // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor> &dk_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               c10::optional<at::Tensor> &dv_,           // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q,           // b+1
+               const at::Tensor &cu_seqlens_k,           // b+1
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+               const int max_seqlen_q,
+               const int max_seqlen_k, // max sequence length to choose the kernel
+               const float p_dropout,  // probability to drop
+               const float softmax_scale,
+               const bool zero_tensors,
+               const bool is_causal,
+               int window_size_left,
+               int window_size_right,
+               const float /*softcap*/,
+               const bool deterministic,
+               c10::optional<at::Generator> gen_,
+               c10::optional<at::Tensor> &rng_state)
+{
+#ifdef FLASHATTENTION_DISABLE_BACKWARD
+    TORCH_CHECK(false, "This flash attention build does not support backward.");
+#endif
+    if (is_causal) { window_size_right = 0; }
+
+    bool is_dropout = p_dropout > 0.0;
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
+    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
+    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
+    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+
+    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
+
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
+    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
+    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
+    CHECK_CONTIGUOUS(cu_seqlens_q);
+    CHECK_CONTIGUOUS(cu_seqlens_k);
+
+    const auto sizes = q.sizes();
+
+    const int total_q = sizes[0];
+    const int batch_size = cu_seqlens_q.numel() - 1;
+    const int num_heads = sizes[1];
+    const int head_size_og = dout.size(2);
+    const int head_size_8x = sizes[2];
+    const int total_k = k.size(0);
+    const int num_heads_k = k.size(1);
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
+    TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8");
+    TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
+
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
+    mask_info mask;
+    if (is_causal) {
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
+        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual
+    }
+    else if (window_size_left == -1 && window_size_right == -1) {
+        mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
+    }
+    else {
+        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
+        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
+    }
+
+    // q, k, v, out had been padded in mha_fwd
+    // dq_, dk_, dv_ are also padded tensor
+    CHECK_SHAPE(q, total_q, num_heads, head_size_8x);
+    CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x);
+    CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x);
+    CHECK_SHAPE(out, total_q, num_heads, head_size_8x);
+    CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
+    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+
+    at::Tensor dq, dk, dv;
+    if (dq_.has_value()) {
+        dq = dq_.value();
+        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
+        CHECK_DEVICE(dq);
+        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
+        CHECK_SHAPE(dq, total_q, num_heads, head_size_8x);
+    } else {
+        dq = torch::empty_like(q);
+    }
+    if (dk_.has_value()) {
+        dk = dk_.value();
+        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
+        CHECK_DEVICE(dk);
+        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
+        CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x);
+    } else {
+        dk = torch::empty_like(k);
+    }
+    if (dv_.has_value()) {
+        dv = dv_.value();
+        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
+        CHECK_DEVICE(dv);
+        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
+        CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x);
+    } else {
+        dv = torch::empty_like(v);
+    }
+
+    at::Tensor dout_padded;
+    if (head_size_og % 8 != 0) {
+        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    } else {
+        dout_padded = dout;
+    }
+
+
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+    auto opts = q.options();
+    auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+    // TODO - CK does not support dq_accum
+
+    at::Tensor dk_expanded, dv_expanded;
+    if (num_heads_k != num_heads) {  // MQA / GQA
+        dk_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
+        dv_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
+    } else {
+        dk_expanded = dk;
+        dv_expanded = dv;
+    }
+
+    if(zero_tensors) {
+        dq.zero_();
+        dk_expanded.zero_();
+        dv_expanded.zero_();
+        softmax_d.zero_();
+    }
+
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+    uint64_t drop_seed = 1, drop_offset = 0;
+    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+
+    if (rng_state.has_value()) {
+        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+        drop_seed = d[0];
+        drop_offset = d[1];
+    } else if(is_dropout) {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(gen->mutex_);
+        auto philox_args = gen->philox_cuda_state(counter_offset);
+        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+    }
+
+    if (max_seqlen_q > 0) {
+        ck_tile::stream_config stream_config{stream};
+        dq.zero_(); // ck use atomic operation on dq
+
+        auto traits =
+            get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
+
+        auto args =
+            get_ck_fmha_varlen_bwd_args(
+                mask,
+                batch_size,
+                max_seqlen_q,
+                max_seqlen_k,
+                num_heads,
+                num_heads_k,
+                head_size_8x,
+                q,
+                k,
+                v,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                alibi_slopes_,
+                out,
+                softmax_lse,
+                dout_padded,
+                softmax_d,
+                dq,
+                dk_expanded,
+                dv_expanded,
+                softmax_scale,
+                p_dropout,
+                drop_seed,
+                drop_offset);
+
+        fmha_bwd(traits, args, stream_config);
+    } else {
+        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+        dk_expanded.zero_();
+        dv_expanded.zero_();
+        softmax_d.zero_();
+    }
+
+    // For MQA/GQA we need to sum dK and dV across the groups
+    if (num_heads_k != num_heads) {
+        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
+        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
+    }
+    if (head_size_og % 8 != 0) {
+        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+    }
+
+    return { dq, dk, dv, softmax_d };
+}

+ 371 - 0
csrc/flash_attn_ck/mha_varlen_fwd.cpp

@@ -0,0 +1,371 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include "flash_common.hpp"
+
+#include "fmha_fwd.hpp"
+#include "mask.hpp"
+
+fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
+                                              std::string dtype,
+                                              int head_size,
+                                              bool has_dropout,
+                                              bool has_lse,
+                                              bool enable_alibi)
+{
+    return fmha_fwd_traits{head_size,
+                           head_size,
+                           dtype,
+                           true, // is_group_mode
+                           true, // is_v_rowmajor
+                           mask.type,
+                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
+                           has_lse,
+                           has_dropout,
+                           false}; // do_fp8_static_quant
+}
+
+fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
+                                          bool has_dropout_randval,
+                                          const mask_info &mask,
+                                          // sizes
+                                          const int b,
+                                          const int max_seqlen_q,
+                                          const int h,
+                                          const int h_k,
+                                          const int d,
+                                          // device pointers
+                                          const at::Tensor q,
+                                          const at::Tensor k,
+                                          const at::Tensor v,
+                                          const at::Tensor seqlens_q,
+                                          const at::Tensor seqlens_k,
+                                          c10::optional<at::Tensor> &alibi_slopes_,
+                                          at::Tensor out,
+                                          at::Tensor softmax_lse,
+                                          at::Tensor dropout_randval,
+                                          float softmax_scale,
+                                          float p_dropout,
+                                          uint64_t drop_seed,
+                                          uint64_t drop_offset)
+{
+    // q: (total_q, nheads, d)
+    // k: (total_k, nheads_k, d)
+    // v: (total_k, nheads_k, d)
+    // o: (total_q, nheads, d)
+
+    // alibi_slopes:(batch, nheads) or (nhead)
+    // lse: (batch, nheads, max_seqlen_q)
+    // randval: (nheads, total_q, max_seqlen_k)
+
+    ck_tile::index_t total_q = q.size(0);
+    ck_tile::index_t total_k = k.size(0);
+
+    ck_tile::index_t stride_q = q.stride(0);
+    ck_tile::index_t stride_k = k.stride(0);
+    ck_tile::index_t stride_v = v.stride(0);
+    ck_tile::index_t stride_o = out.stride(0);
+    ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
+
+    ck_tile::index_t nhead_stride_q = q.stride(1);
+    ck_tile::index_t nhead_stride_k = k.stride(1);
+    ck_tile::index_t nhead_stride_v = v.stride(1);
+    ck_tile::index_t nhead_stride_o = out.stride(1);
+    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
+    ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
+
+    ck_tile::index_t batch_stride_q = 0;
+    ck_tile::index_t batch_stride_k = 0;
+    ck_tile::index_t batch_stride_v = 0;
+    ck_tile::index_t batch_stride_o = 0;
+
+    ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
+    ck_tile::index_t batch_stride_randval = 0;
+
+    void *alibi_slopes_ptr = nullptr;
+    ck_tile::index_t stride_alibi_slopes = 0;
+
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
+        alibi_slopes_ptr = alibi_slopes.data_ptr();
+        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+    }
+
+    return fmha_fwd_args{q.data_ptr(),
+                         k.data_ptr(),
+                         v.data_ptr(),
+                         alibi_slopes_ptr, // bias
+                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
+                         nullptr, // lse_acc
+                         nullptr, // o_acc
+                         has_lse ? softmax_lse.data_ptr() : nullptr,
+                         out.data_ptr(),
+                         seqlens_q.data_ptr(), // seqstart_q
+                         seqlens_k.data_ptr(), // seqstart_k
+                         nullptr,              // seqlen_kpads
+                         total_q,
+                         total_k,
+                         b,
+                         max_seqlen_q,
+                         d,             // hdim_q
+                         d,             // hdim_v
+                         h,             // nhead
+                         h_k,           // nhead_k
+                         1,             // num_splits
+                         softmax_scale, // scale_s
+                         1,             // scale_p
+                         1,             // scale_o
+                         stride_q,
+                         stride_k,
+                         stride_v,
+                         stride_alibi_slopes,
+                         stride_randval,
+                         0, // stride_o_acc,
+                         stride_o,
+                         nhead_stride_q,
+                         nhead_stride_k,
+                         nhead_stride_v,
+                         0, // nhead_stride_bias, FA without bias
+                         nhead_stride_randval,
+                         nhead_stride_lse,
+                         0, // nhead_stride_lse_acc
+                         0, // nhead_stride_o_acc
+                         nhead_stride_o,
+                         batch_stride_q,
+                         batch_stride_k,
+                         batch_stride_v,
+                         0, // batch_stride_bias, FA without bias
+                         batch_stride_randval,
+                         batch_stride_lse,
+                         0, // batch_stride_lse_acc
+                         0, // batch_stride_o_acc
+                         batch_stride_o,
+                         0, // split_stride_lse_acc
+                         0, // split_stride_o_acc
+                         mask.left,
+                         mask.right,
+                         static_cast<ck_tile::index_t>(mask.type),
+                         p_dropout,
+                         has_dropout_randval,
+                         {drop_seed, drop_offset}};
+}
+
+std::vector<at::Tensor>
+mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q,  // b+1
+               const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> & /*seqused_k*/,
+               c10::optional<const at::Tensor> &/*leftpad_k_*/, // batch_size
+               c10::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
+               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
+               int max_seqlen_q,
+               const int max_seqlen_k,
+               const float p_dropout,
+               const float softmax_scale,
+               const bool zero_tensors,
+               bool is_causal,
+               int window_size_left,
+               int window_size_right,
+               const float /*softcap*/,
+               const bool return_dropout_randval,
+               c10::optional<at::Generator> gen_)
+{
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
+    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+
+    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
+
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(cu_seqlens_q);
+    CHECK_DEVICE(cu_seqlens_k);
+
+    // TODO - Support paged_KV
+    const bool paged_KV = block_table_.has_value();
+    TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    CHECK_CONTIGUOUS(cu_seqlens_q);
+    CHECK_CONTIGUOUS(cu_seqlens_k);
+
+    const auto sizes = q.sizes();
+
+    const int batch_size = cu_seqlens_q.numel() - 1;
+    int num_heads = sizes[1];
+    const int head_size_og = sizes[2];
+    const int num_heads_k = k.size(1);
+
+    const int max_num_blocks_per_seq = 0;
+    const int num_blocks = 0;
+
+    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
+
+    // TODO
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+
+    const int total_q = q.size(0);
+    const int total_k = k.size(0);
+
+    TORCH_CHECK(batch_size > 0, "batch size must be postive");
+    TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
+    mask_info mask;
+
+    if (is_causal) {
+        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+        window_size_right = 0;
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
+        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual
+    }
+    else if (window_size_left == -1 && window_size_right == -1) {
+        mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
+    }
+    else {
+        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
+        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
+    }
+
+    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
+    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+
+    at::Tensor q_padded, k_padded, v_padded;
+    if (head_size_og % 8 != 0) {
+        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    }
+    else {
+        q_padded = q;
+        k_padded = k;
+        v_padded = v;
+    }
+
+    at::Tensor out;
+    if (out_.has_value()) {
+        out = out_.value();
+        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        CHECK_DEVICE(out);
+        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+        CHECK_SHAPE(out, total_q, num_heads, head_size_og);
+
+        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+    }
+    else {
+        out = torch::empty_like(q_padded);
+    }
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size_8x = round_multiple(head_size_og, 8);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+    auto opts = q.options();
+    bool has_lse = true;
+    bool has_dropout = p_dropout > 0.0f;
+
+    at::Tensor softmax_lse;
+    // TODO - check gradient, only training require lse
+    softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(torch::kFloat32));
+
+    at::Tensor p;
+    if (return_dropout_randval) {
+        TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
+        p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
+    }
+
+    if (zero_tensors)
+    {
+        out.zero_();
+        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
+        if (return_dropout_randval) {p.zero_();}
+    }
+
+    uint64_t drop_seed = 1, drop_offset = 0;
+    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+
+    if (p_dropout > 0.0)  {
+        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+            gen_, at::cuda::detail::getDefaultCUDAGenerator());
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(gen->mutex_);
+        auto philox_args = gen->philox_cuda_state(counter_offset);
+        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+    }
+
+    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
+    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
+
+    if (max_seqlen_k > 0) {
+        auto stream = at::cuda::getCurrentHIPStream().stream();
+        ck_tile::stream_config stream_config{stream};
+
+        auto traits =
+            get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());
+
+        auto args =
+            get_ck_fmha_varlen_fwd_args(
+                has_lse,
+                return_dropout_randval,
+                mask,
+                batch_size,
+                max_seqlen_q,
+                num_heads,
+                num_heads_k,
+                head_size_8x,
+                q_padded,
+                k_padded,
+                v_padded,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                alibi_slopes_,
+                out,
+                softmax_lse,
+                p,
+                softmax_scale,
+                p_dropout,
+                drop_seed,
+                drop_offset);
+
+        fmha_fwd(traits, args, stream_config);
+    }
+    else {
+        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits<float>::infinity());
+    }
+
+    at::Tensor out_padded = out;
+    if (head_size_og % 8 != 0) {
+        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        if (out_.has_value()) { out_.value().copy_(out); }
+    }
+
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
+}

+ 155 - 13
setup.py

@@ -5,6 +5,8 @@ import warnings
 import os
 import re
 import ast
+import glob
+import shutil
 from pathlib import Path
 from packaging.version import parse, Version
 import platform
@@ -22,6 +24,8 @@ from torch.utils.cpp_extension import (
     CppExtension,
     CUDAExtension,
     CUDA_HOME,
+    ROCM_HOME,
+    IS_HIP_EXTENSION,
 )
 
 
@@ -32,6 +36,19 @@ with open("README.md", "r", encoding="utf-8") as fh:
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
+BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
+
+if BUILD_TARGET == "auto":
+    if IS_HIP_EXTENSION:
+        IS_ROCM = True
+    else:
+        IS_ROCM = False
+else:
+    if BUILD_TARGET == "cuda":
+        IS_ROCM = False
+    elif BUILD_TARGET == "rocm":
+        IS_ROCM = True
+
 PACKAGE_NAME = "flash_attn"
 
 BASE_WHEEL_URL = (
@@ -82,19 +99,47 @@ def check_if_cuda_home_none(global_option: str) -> None:
     )
 
 
+def check_if_rocm_home_none(global_option: str) -> None:
+    if ROCM_HOME is not None:
+        return
+    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
+    # in that case.
+    warnings.warn(
+        f"{global_option} was requested, but hipcc was not found."
+    )
+
+
 def append_nvcc_threads(nvcc_extra_args):
     nvcc_threads = os.getenv("NVCC_THREADS") or "4"
     return nvcc_extra_args + ["--threads", nvcc_threads]
 
 
+def rename_cpp_to_cu(cpp_files):
+    for entry in cpp_files:
+        shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
+
+
+def validate_and_update_archs(archs):
+    # List of allowed architectures
+    allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
+
+    # Validate if each element in archs is in allowed_archs
+    assert all(
+        arch in allowed_archs for arch in archs
+    ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"
+
+
 cmdclass = {}
 ext_modules = []
 
 # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
 # files included in the source distribution, in case the user compiles from source.
-subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
+if IS_ROCM:
+    subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
+else:
+    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
 
-if not SKIP_CUDA_BUILD:
+if not SKIP_CUDA_BUILD and not IS_ROCM:
     print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
     TORCH_MAJOR = int(torch.__version__.split(".")[0])
     TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -250,6 +295,95 @@ if not SKIP_CUDA_BUILD:
             ],
         )
     )
+elif not SKIP_CUDA_BUILD and IS_ROCM:
+    ck_dir = "csrc/composable_kernel"
+
+    #use codegen get code dispatch
+    if not os.path.exists("./build"):
+        os.makedirs("build")
+
+    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
+    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")
+
+    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+    # See https://github.com/pytorch/pytorch/pull/70650
+    generator_flag = []
+    torch_dir = torch.__path__[0]
+    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+        generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+    check_if_rocm_home_none("flash_attn")
+    cc_flag = []
+
+    archs = os.getenv("GPU_ARCHS", "native").split(";")
+    validate_and_update_archs(archs)
+
+    cc_flag = [f"--offload-arch={arch}" for arch in archs]
+
+    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
+    # torch._C._GLIBCXX_USE_CXX11_ABI
+    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
+    if FORCE_CXX11_ABI:
+        torch._C._GLIBCXX_USE_CXX11_ABI = True
+
+    sources = ["csrc/flash_attn_ck/flash_api.cpp",
+               "csrc/flash_attn_ck/mha_bwd.cpp",
+               "csrc/flash_attn_ck/mha_fwd.cpp",
+               "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
+               "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
+        f"build/fmha_*wd*.cpp"
+    )
+
+    rename_cpp_to_cu(sources)
+
+    renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
+                       "csrc/flash_attn_ck/mha_bwd.cu",
+                       "csrc/flash_attn_ck/mha_fwd.cu",
+                       "csrc/flash_attn_ck/mha_varlen_bwd.cu",
+                       "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
+    extra_compile_args = {
+        "cxx": ["-O3", "-std=c++17"] + generator_flag,
+        "nvcc":
+            [
+                "-O3","-std=c++17",
+                "-mllvm", "-enable-post-misched=0",
+                "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
+                "-fgpu-flush-denormals-to-zero",
+                "-DCK_ENABLE_BF16",
+                "-DCK_ENABLE_BF8",
+                "-DCK_ENABLE_FP16",
+                "-DCK_ENABLE_FP32",
+                "-DCK_ENABLE_FP64",
+                "-DCK_ENABLE_FP8",
+                "-DCK_ENABLE_INT8",
+                "-DCK_USE_XDL",
+                "-DUSE_PROF_API=1",
+                "-D__HIP_PLATFORM_HCC__=1",
+                # "-DFLASHATTENTION_DISABLE_BACKWARD",
+            ]
+            + generator_flag
+            + cc_flag
+        ,
+    }
+
+    include_dirs = [
+        Path(this_dir) / "csrc" / "composable_kernel" / "include",
+        Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
+        Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
+    ]
+
+    ext_modules.append(
+        CUDAExtension(
+            name="flash_attn_2_cuda",
+            sources=renamed_sources,
+            extra_compile_args=extra_compile_args,
+            include_dirs=include_dirs,
+        )
+    )
 
 
 def get_package_version():
@@ -264,25 +398,33 @@ def get_package_version():
 
 
 def get_wheel_url():
-    # Determine the version numbers that will be used to determine the correct wheel
-    # We're using the CUDA version used to build torch, not the one currently installed
-    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-    torch_cuda_version = parse(torch.version.cuda)
     torch_version_raw = parse(torch.__version__)
-    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
-    # to save CI time. Minor versions should be compatible.
-    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
     platform_name = get_platform()
     flash_version = get_package_version()
-    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
     torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
     cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
 
-    # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    if IS_ROCM:
+        torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
+        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
+        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    else:
+        # Determine the version numbers that will be used to determine the correct wheel
+        # We're using the CUDA version used to build torch, not the one currently installed
+        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+        torch_cuda_version = parse(torch.version.cuda)
+        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
+        # to save CI time. Minor versions should be compatible.
+        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
+        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+
+        # Determine wheel URL based on CUDA version, torch version, python version and OS
+        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+
     wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
+
     return wheel_url, wheel_filename
 
 

+ 754 - 0
tests/test_flash_attn_ck.py

@@ -0,0 +1,754 @@
+import math
+
+import pytest
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from flash_attn import (
+    flash_attn_func,
+    flash_attn_kvpacked_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_qkvpacked_func,
+)
+
+from test_flash_attn import (
+    attn_bias_from_alibi_slopes,
+    convert_flash_attn_S_to_softmax,
+    generate_qkv,
+    generate_random_padding_mask,
+    attention_ref,
+    attention_kvpacked_ref,
+    attention_qkvpacked_ref,
+)
+
+def is_bwd_hdim_supported(d):
+    return d <= 128 and d % 2 == 0
+
+
+def ck_randval_to_dropout_mask(randval, p):
+    # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout
+    # randval in 255 * [0, 0.7] will be kept
+    # If return dropout_mask >=0, value will be kept
+    return torch.floor(255.0 * (1 - p) - randval)
+
+
+def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded):
+    """ pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded]
+    Arguments:
+        S_dmask: (nheads, total_q, max_seqlen_k)
+        cu_seqlens_q: (b + 1)
+    Output:
+        S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded)
+    """
+    batch_size = cu_seqlens_q.numel() - 1
+    seqlens_q = torch.roll(cu_seqlens_q, shifts = -1) - cu_seqlens_q
+    seqlens_q = seqlens_q[0:batch_size].tolist()
+    S_dmask = torch.split(S_dmask, seqlens_q, dim=1)
+    # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)]
+    masks = ()
+    for mask in S_dmask:
+        # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded)
+        mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1)
+        masks = masks + (mask, )
+    S_dmask = torch.cat(masks, dim=1)
+
+    S_dmask = S_dmask.transpose(0, 1)
+    return S_dmask
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("alibi", [False, True])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
+@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
+    if d > 256:
+        pytest.skip()
+
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
+
+    qkv = torch.randn(
+        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
+    )
+
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
+    else:
+        alibi_slopes, attn_bias = None, None
+    out, lse, S_dmask = flash_attn_qkvpacked_func(
+        qkv,
+        dropout_p,
+        causal=causal,
+        window_size=window_size,
+        alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
+        return_attn_probs=True,
+    )
+    if dropout_p > 0.0:
+        # TODO - move to c++ mha_varlen_fwd()
+        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
+        S_dmask_converted = convert_flash_attn_S_to_softmax(
+            S_dmask,
+            seqlen,
+            seqlen,
+            None,
+            None,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
+            window_size=window_size,
+        )
+        dropout_mask = S_dmask_converted >= 0
+        # CK does not return P. Hence, we don't test the attn here.
+    else:
+        dropout_mask = None
+
+    out_ref, attn_ref = attention_qkvpacked_ref(
+        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
+    )
+    out_pt, attn_pt = attention_qkvpacked_ref(
+        qkv,
+        None,
+        attn_bias,
+        dropout_p,
+        dropout_mask,
+        causal=causal,
+        window_size=window_size,
+        upcast=False,
+        reorder_ops=True,
+    )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+
+    g = torch.randn_like(out)
+    if is_bwd_hdim_supported(d):
+        (dqkv,) = torch.autograd.grad(out, qkv, g)
+        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
+        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
+        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
+        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
+        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
+        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
+        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
+        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
+        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
+        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
+
+        # TODO - use 10 times to check, wait for ck to change dq type to f32
+        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("alibi", [False, True])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
+@pytest.mark.parametrize("dropout_p", [0, 0.17])
+def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
+    if d > 256:
+        pytest.skip()
+
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 5
+    nheads = 6
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
+    qkv = torch.randn(
+        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
+    )
+
+    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
+    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(
+            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
+        )
+    else:
+        alibi_slopes, attn_bias = None, None
+
+    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
+        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
+    )
+
+    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
+        qkv_unpad,
+        cu_seqlens,
+        max_seqlen,
+        dropout_p,
+        causal=causal,
+        window_size=window_size,
+        alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
+        return_attn_probs=True,
+    )
+    out = output_pad_fn(out_unpad)
+    if dropout_p > 0.0:
+        # TODO - move to c++ mha_varlen_fwd()
+        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
+        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens, seqlen, seqlen)
+
+        S_dmask_converted = convert_flash_attn_S_to_softmax(
+            S_dmask,
+            seqlen,
+            seqlen,
+            key_padding_mask,
+            key_padding_mask,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
+            window_size=window_size,
+        )
+
+        dropout_mask = S_dmask_converted >= 0
+        # CK does not return P. Hence, we don't test the attn here.
+    else:
+        dropout_mask = None
+
+    out_ref, attn_ref = attention_qkvpacked_ref(
+        qkv,
+        key_padding_mask,
+        attn_bias,
+        dropout_p,
+        dropout_mask,
+        causal=causal,
+        window_size=window_size,
+    )
+    out_pt, attn_pt = attention_qkvpacked_ref(
+        qkv,
+        key_padding_mask,
+        attn_bias,
+        dropout_p,
+        dropout_mask,
+        causal=causal,
+        window_size=window_size,
+        upcast=False,
+        reorder_ops=True,
+    )
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+
+    g = torch.randn_like(out)
+    if is_bwd_hdim_supported(d):
+        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
+        dqkv = dqkv_pad_fn(dqkv_unpad)
+        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
+        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
+        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
+        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
+        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
+        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
+        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
+        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
+        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
+        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")
+
+        # TODO - use 10 times to check, wait for ck to change dq type to f32
+        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()
+
+
+@pytest.mark.parametrize("kvpacked", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+@pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("alibi", [False, True])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (512, 256),
+        (1024, 1024),
+        (1023, 1024),
+        (1024, 1023),
+        (2048, 2048),
+    ],
+)
+@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+def test_flash_attn_output(
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+):
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    nheads = 9
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    assert nheads % nheads_k == 0
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if kvpacked:
+        kv = torch.randn(
+            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+    else:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
+    else:
+        alibi_slopes, attn_bias = None, None
+
+    if kvpacked:
+        out, lse, S_dmask = flash_attn_kvpacked_func(
+            q,
+            kv,
+            dropout_p,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
+            return_attn_probs=True,
+        )
+    else:
+        out, lse, S_dmask = flash_attn_func(
+            q,
+            k,
+            v,
+            dropout_p,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
+            return_attn_probs=True,
+        )
+    if dropout_p > 0.0:
+        # TODO - move to c++ mha_varlen_fwd()
+        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
+        S_dmask_converted = convert_flash_attn_S_to_softmax(
+            S_dmask,
+            seqlen_q,
+            seqlen_k,
+            None,
+            None,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
+            window_size=window_size,
+        )
+        dropout_mask = S_dmask_converted >= 0
+        if kvpacked:
+            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
+            k_rep, v_rep = kv_rep.unbind(dim=2)
+        else:
+            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+        # CK does not return P. Hence, we don't test the attn here.
+    else:
+        dropout_mask = None
+
+    if kvpacked:
+        out_ref, attn_ref = attention_kvpacked_ref(
+            q,
+            kv,
+            None,
+            None,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+        )
+        out_pt, attn_pt = attention_kvpacked_ref(
+            q,
+            kv,
+            None,
+            None,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+            upcast=False,
+            reorder_ops=True,
+        )
+    else:
+        out_ref, attn_ref = attention_ref(
+            q,
+            k,
+            v,
+            None,
+            None,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+        )
+        out_pt, attn_pt = attention_ref(
+            q,
+            k,
+            v,
+            None,
+            None,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+            upcast=False,
+            reorder_ops=True,
+        )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+
+    g = torch.randn_like(out)
+    if is_bwd_hdim_supported(d):
+        if kvpacked:
+            (
+                dq,
+                dkv,
+            ) = torch.autograd.grad(out, (q, kv), g)
+            dk, dv = dkv.unbind(2)
+            (
+                dq_ref,
+                dkv_ref,
+            ) = torch.autograd.grad(out_ref, (q, kv), g)
+            dk_ref, dv_ref = dkv_ref.unbind(2)
+            (
+                dq_pt,
+                dkv_pt,
+            ) = torch.autograd.grad(out_pt, (q, kv), g)
+            dk_pt, dv_pt = dkv_pt.unbind(2)
+        else:
+            (
+                dq,
+                dk,
+                dv,
+            ) = torch.autograd.grad(out, (q, k, v), g)
+            (
+                dq_ref,
+                dk_ref,
+                dv_ref,
+            ) = torch.autograd.grad(out_ref, (q, k, v), g)
+            (
+                dq_pt,
+                dk_pt,
+                dv_pt,
+            ) = torch.autograd.grad(out_pt, (q, k, v), g)
+        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+        # TODO - use 10 times to check, wait for ck to change dq type to f32
+        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
+        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
+        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()
+
+
+@pytest.mark.parametrize("kvpacked", [True, False])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+@pytest.mark.parametrize("deterministic", [False, True])
+@pytest.mark.parametrize("alibi", [False, True])
+@pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 147),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (512, 256),
+        (1024, 1024),
+        (1023, 1024),
+        (1024, 1023),
+        (2048, 2048),
+    ],
+)
+@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+def test_flash_attn_varlen_output(
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+):
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    nheads = 9
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    assert nheads % nheads_k == 0
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if kvpacked:
+        kv = torch.randn(
+            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+    else:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
+        )
+
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
+    if alibi:
+        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
+        attn_bias = attn_bias_from_alibi_slopes(
+            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
+        )
+    else:
+        alibi_slopes, attn_bias = None, None
+
+    if kvpacked:
+        (
+            q_unpad,
+            kv_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q,
+            kv,
+            output_pad_fn,
+            dq_pad_fn,
+            dkv_pad_fn,
+        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
+        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
+            q_unpad,
+            kv_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
+            return_attn_probs=True,
+        )
+    else:
+        (
+            q_unpad,
+            k_unpad,
+            v_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q,
+            k,
+            v,
+            output_pad_fn,
+            dq_pad_fn,
+            dk_pad_fn,
+        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
+            q_unpad,
+            k_unpad,
+            v_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
+            return_attn_probs=True,
+        )
+    out = output_pad_fn(out_unpad)
+    if dropout_p > 0.0:
+        # TODO - move to c++ mha_varlen_fwd()
+        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
+        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q, seqlen_k)
+        S_dmask_converted = convert_flash_attn_S_to_softmax(
+            S_dmask,
+            seqlen_q,
+            seqlen_k,
+            query_padding_mask,
+            key_padding_mask,
+            d,
+            dropout_p > 0.0,
+            causal=causal,
+            window_size=window_size,
+        )
+        dropout_mask = S_dmask_converted >= 0
+        if kvpacked:
+            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
+            k_rep, v_rep = kv_rep.unbind(dim=2)
+        else:
+            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+        # CK does not return P. Hence, we don't test the attn here.
+    else:
+        dropout_mask = None
+
+    if kvpacked:
+        out_ref, attn_ref = attention_kvpacked_ref(
+            q,
+            kv,
+            query_padding_mask,
+            key_padding_mask,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+        )
+        out_pt, attn_pt = attention_kvpacked_ref(
+            q,
+            kv,
+            query_padding_mask,
+            key_padding_mask,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+            upcast=False,
+            reorder_ops=True,
+        )
+    else:
+        out_ref, attn_ref = attention_ref(
+            q,
+            k,
+            v,
+            query_padding_mask,
+            key_padding_mask,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+        )
+        out_pt, attn_pt = attention_ref(
+            q,
+            k,
+            v,
+            query_padding_mask,
+            key_padding_mask,
+            attn_bias,
+            dropout_p,
+            dropout_mask,
+            causal=causal,
+            window_size=window_size,
+            upcast=False,
+            reorder_ops=True,
+        )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most 4 times the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item()
+
+    g = torch.randn_like(out)
+    if is_bwd_hdim_supported(d):
+        if kvpacked:
+            (
+                dq_unpad,
+                dkv_unpad,
+            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
+            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
+            (
+                dq_ref,
+                dkv_ref,
+            ) = torch.autograd.grad(out_ref, (q, kv), g)
+            dk_ref, dv_ref = dkv_ref.unbind(2)
+            (
+                dq_pt,
+                dkv_pt,
+            ) = torch.autograd.grad(out_pt, (q, kv), g)
+            dk_pt, dv_pt = dkv_pt.unbind(2)
+        else:
+            (
+                dq_unpad,
+                dk_unpad,
+                dv_unpad,
+            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
+            dk = dk_pad_fn(dk_unpad)
+            dv = dk_pad_fn(dv_unpad)
+            (
+                dq_ref,
+                dk_ref,
+                dv_ref,
+            ) = torch.autograd.grad(out_ref, (q, k, v), g)
+            (
+                dq_pt,
+                dk_pt,
+                dv_pt,
+            ) = torch.autograd.grad(out_pt, (q, k, v), g)
+        dq = dq_pad_fn(dq_unpad)
+        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+        # TODO - use 10 times to check, wait for ck to change dq type to f32
+        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
+        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
+        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()