瀏覽代碼

Support alibi, by Sanghun Cho from Kakao Brain

* hard-code alibi in fwd

* use params.h as hun_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue
by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------

Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
Sanghun Cho 1 年之前
父節點
當前提交
e4f726fc44

+ 75 - 3
csrc/flash_attn/flash_api.cpp

@@ -258,6 +258,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         bool is_causal,
         const int window_size_left,
         int window_size_right,
+        c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
 
@@ -301,7 +302,8 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0;
+    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -409,6 +411,19 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        params.has_alibi = true;
+        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+    } else {
+        params.has_alibi = false;
+    }
+
     if (seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
         run_mha_fwd(params, stream);
@@ -449,6 +464,7 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
+               c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {
 
@@ -591,6 +607,19 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        params.has_alibi = true;
+        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+    } else {
+        params.has_alibi = false;
+    }
+
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     run_mha_fwd(params, stream);
 
@@ -640,6 +669,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
+        c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
 
@@ -813,6 +843,19 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         params.rng_state[1] = std::get<1>(seeds);
     }
 
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        params.has_alibi = true;
+        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+    } else {
+        params.has_alibi = false;
+    }
+
     if (seqlen_q > 0) {
         launch(params, stream, /*configure=*/false);
     } else {
@@ -856,6 +899,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
+               c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {
 
@@ -1045,6 +1089,19 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         params.rng_state[1] = std::get<1>(seeds);
     }
 
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        params.has_alibi = true;
+        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+    } else {
+        params.has_alibi = false;
+    }
+
     launch(params, stream, /*configure=*/false);
 
     // For MQA/GQA we need to sum dK and dV across the groups
@@ -1077,7 +1134,8 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                 const int window_size_left,
                 int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits
+                int num_splits,
+                c10::optional<at::Tensor> &alibi_slopes_ // batch_size x num_heads
                 ) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1121,7 +1179,8 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
 
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0;
+    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -1283,6 +1342,19 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
         params.oaccum_ptr = out_accum.data_ptr();
     }
 
+    if (alibi_slopes_.has_value()) {
+        auto alibi_slopes = alibi_slopes_.value();
+        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
+        CHECK_DEVICE(alibi_slopes);
+        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
+        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        params.has_alibi = true;
+        params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+    } else {
+        params.has_alibi = false;
+    }
+
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
     run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());

+ 54 - 0
csrc/flash_attn/src/alibi.h

@@ -0,0 +1,54 @@
+#include <cmath>
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+
+#include "utils.h"
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout>
+inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor, 
+                                   const int col_idx_offset_,
+                                   const int max_seqlen_k, 
+                                   const int row_idx_offset_,
+                                   const int max_seqlen_q, 
+                                   const int warp_row_stride,
+                                   const int head_idx,
+                                   const float softmax_scale,
+                                   const float alibi_slope) {
+    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    static_assert(Layout::rank == 2, "Only support 2D Tensor");
+    const int lane_id = threadIdx.x % 32;
+    const int row_idx_offset = row_idx_offset_;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const float alibi_slope_unscaled = alibi_slope / softmax_scale;
+    #pragma unroll
+    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        #pragma unroll
+        for (int i = 0; i < size<0, 0>(tensor); ++i) {
+            const int row_idx = row_idx_base + i * 8;
+            #pragma unroll
+            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                const int col_idx_base = col_idx_offset + nj * 8;
+                #pragma unroll
+                for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                    const int col_idx = col_idx_base + j;
+                    const float alibi = alibi_slope_unscaled * col_idx;
+                    if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
+                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
+                    }
+                }
+            }
+        }
+    }
+}
+
+}  // namespace flash

+ 7 - 0
csrc/flash_attn/src/flash.h

@@ -130,6 +130,13 @@ struct Flash_fwd_params : public Qkv_params {
     bool is_rotary_interleaved;
 
     int num_splits;  // For split-KV version
+
+    // float alibi_start;
+    // float alibi_ratio;
+
+    bool has_alibi;
+    void * __restrict__ alibi_slopes_ptr;
+    index_t alibi_slopes_batch_stride;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+ 69 - 11
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -15,6 +15,8 @@
 #include "utils.h"
 #include "softmax.h"
 
+#include "alibi.h"
+
 namespace flash {
 
 using namespace cute;
@@ -422,7 +424,7 @@ inline __device__ void convert_dKV(const Params &params) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -790,6 +792,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);
 
+    float alibi_slope = 0.0f;
+    if (Has_alibi) {
+        Tensor gAS = make_tensor(
+        make_gmem_ptr(
+            reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) 
+            + bidb * params.alibi_slopes_batch_stride + bidh
+        ),
+        Shape<_1>{});
+        Tensor rAS = make_fragment_like(gAS);
+        cute::copy(gAS, rAS);
+        alibi_slope = rAS(0);
+    }
+
     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)
         clear(acc_s);
@@ -813,6 +828,20 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         // if (cute::thread(32, 0)) { print(scores); }
+
+        if (Has_alibi) {
+            flash::apply_alibi(
+                scores, 
+                n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                binfo.actual_seqlen_k, 
+                m_block * kBlockM + get<0>(taccScS_row(0)),
+                binfo.actual_seqlen_q, 
+                AtomLayoutMS * 16,
+                bidh, params.scale_softmax, 
+                alibi_slope
+            );
+        }
+        
         // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
         // actual_seqlen_k, because acc_s would be some finite value for those indices.
         // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
@@ -849,6 +878,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             }
 
         }
+
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
         flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
@@ -1114,7 +1144,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -1373,6 +1403,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
 
     clear(acc_dq);
 
+    float alibi_slope = 0.0f;
+    if (Has_alibi) {
+        Tensor gAS = make_tensor(
+        make_gmem_ptr(
+            reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) 
+            + bidb * params.alibi_slopes_batch_stride + bidh
+        ),
+        Shape<_1>{});
+        Tensor rAS = make_fragment_like(gAS);
+        cute::copy(gAS, rAS);
+        alibi_slope = rAS(0);
+    }
+
     for (; n_block >= 0; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M_SdP, MMA_N)
         clear(acc_s);
@@ -1384,6 +1427,20 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
 
         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+
+        if (Has_alibi) {
+            flash::apply_alibi(
+                scores, 
+                n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                binfo.actual_seqlen_k, 
+                m_block * kBlockM + get<0>(taccScS_row(0)),
+                binfo.actual_seqlen_q, 
+                AtomLayoutMS * 16,
+                bidh, params.scale_softmax, 
+                alibi_slope
+            );
+        }
+
         // We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would
         // be some finite value for those indices. In the end when we multiply with K to get dQ,
         // the corresponding values of K would be 0, so the result would still be correct.
@@ -1394,6 +1451,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
                                      binfo.actual_seqlen_q,
                                      AtomLayoutMS * 16);
         }
+
         // Compute the exponential value.
         flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
         if (Is_dropout) {
@@ -1536,7 +1594,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     // The block index for the batch.
@@ -1550,20 +1608,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
     if (n_block_max == 1) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
     } else {
         // Iterating backward from n_block_max - 1 to 0 might save 1 register
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
         for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
-            compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
+            compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
         }
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
     }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     const int n_block = blockIdx.x;
@@ -1572,12 +1630,12 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     // The block index for the head.
     const int bidh = blockIdx.z;
 
-    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params &params) {
 
     const int m_block = blockIdx.x;
@@ -1586,7 +1644,7 @@ inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params &params) {
     // The block index for the head.
     const int bidh = blockIdx.z;
 
-    compute_dq_dk_dv_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K>(params, bidb, bidh, m_block);
+    compute_dq_dk_dv_1rowblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+ 30 - 26
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -18,20 +18,20 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
     flash::clear_dKVaccum<Kernel_traits>(params);
 }
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
 __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
+    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
 }
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
 __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
     static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K>(params);
+    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
 }
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K>
 __global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K>(params);
+    flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params);
 }
 
 template<typename Kernel_traits>
@@ -64,17 +64,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                    // If Is_local, set Is_causal to false
-                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true>;
-                    if (smem_size_dq_dk_dv >= 48 * 1024)  {
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                    }
-                    kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                        // If Is_local, set Is_causal to false
+                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
+                        if (smem_size_dq_dk_dv >= 48 * 1024)  {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                        }
+                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    });
                 });
             });
         });
@@ -107,15 +109,17 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
         BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
-                if (smem_size_dq_dk_dv >= 48 * 1024)  {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                }
-                kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
+                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
+                    if (smem_size_dq_dk_dv >= 48 * 1024)  {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                    }
+                    kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
             });
         });
     });

+ 97 - 6
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -15,6 +15,8 @@
 #include "utils.h"
 #include "softmax.h"
 
+#include "alibi.h"
+
 namespace flash {
 
 using namespace cute;
@@ -71,7 +73,7 @@ inline __device__ void write_softmax_to_gmem(
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -326,6 +328,22 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
 
+    float alibi_slope = 0.0f;
+    if (Has_alibi) {
+        Tensor gAS = make_tensor(
+        make_gmem_ptr(
+            reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) 
+            + bidb * params.alibi_slopes_batch_stride + bidh
+        ),
+        Shape<_1>{});
+        Tensor rAS = make_fragment_like(gAS);
+        cute::copy(gAS, rAS);
+        alibi_slope = rAS(0);
+        // if (m_block == 0 && tidx == 0) {
+        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
+        // }
+    }
+
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)
@@ -362,6 +380,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
+
+        if (Has_alibi) {
+            flash::apply_alibi(
+                scores, 
+                n_block * kBlockN, 
+                binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, 
+                kNWarps * 16,
+                bidh, params.scale_softmax, 
+                alibi_slope
+            );
+        }
+
         if (!Is_causal && !Is_local) {
             if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
@@ -466,6 +498,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+        
+        if (Has_alibi) {
+            flash::apply_alibi(
+                scores, 
+                n_block * kBlockN, 
+                binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, 
+                kNWarps * 16,
+                bidh, params.scale_softmax, 
+                alibi_slope
+            );
+        }
+        
         if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
             flash::apply_mask_local(
                 scores, n_block * kBlockN, binfo.actual_seqlen_k,
@@ -474,6 +520,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                 params.window_size_left, params.window_size_right
             );
         }
+
         softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
 
         Tensor rP = flash::convert_type<Element>(scores);
@@ -581,7 +628,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
 
     using Element = typename Kernel_traits::Element;
@@ -909,6 +956,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
 
+    float alibi_slope = 0.0f;
+    if (Has_alibi) {
+        Tensor gAS = make_tensor(
+        make_gmem_ptr(
+            reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr) 
+            + bidb * params.alibi_slopes_batch_stride + bidh
+        ),
+        Shape<_1>{});
+        Tensor rAS = make_fragment_like(gAS);
+        cute::copy(gAS, rAS);
+        alibi_slope = rAS(0);
+        // if (m_block == 0 && tidx == 0) {
+        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
+        // }
+    }
+
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
     constexpr int n_masking_steps = (!Is_causal && !Is_local)
@@ -941,6 +1004,20 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+
+        if (Has_alibi) {
+            flash::apply_alibi(
+                scores, 
+                n_block * kBlockN, 
+                binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, 
+                kNWarps * 16,
+                bidh, params.scale_softmax, 
+                alibi_slope
+            );
+        }
+
         // if (cute::thread0()) { print(scores); }
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
@@ -1020,6 +1097,20 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+
+        if (Has_alibi) {
+            flash::apply_alibi(
+                scores, 
+                n_block * kBlockN, 
+                binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, 
+                kNWarps * 16,
+                bidh, params.scale_softmax, 
+                alibi_slope
+            );
+        }
+
         if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
             flash::apply_mask_local(
                 scores, n_block * kBlockN, binfo.actual_seqlen_k,
@@ -1131,7 +1222,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1147,12 +1238,12 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
 
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1161,7 +1252,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+ 38 - 34
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -10,15 +10,15 @@
 #include "flash.h"
 #include "flash_fwd_kernel.h"
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
 __global__ void flash_fwd_kernel(Flash_fwd_params params) {
     static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
 }
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
 __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }
 
 template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
@@ -45,24 +45,26 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                 BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                    // Will only return softmax if dropout, to reduce compilation time.
-                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                    // If return_softmax, set IsEvenMNConst to false to reduce number of templates
-                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                    // If Is_local, set Is_causal to false
-                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-                    // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
-                    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
-                    if (smem_size >= 48 * 1024) {
-                        C10_CUDA_CHECK(cudaFuncSetAttribute(
-                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                    }
-                    // int ctas_per_sm;
-                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                    // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                        // Will only return softmax if dropout, to reduce compilation time.
+                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                        // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                        // If Is_local, set Is_causal to false
+                        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                        // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+                        if (smem_size >= 48 * 1024) {
+                            C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                        }
+                        // int ctas_per_sm;
+                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    });
                 });
             });
         });
@@ -84,18 +86,20 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                 BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                            // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                            // If Is_local, set Is_causal to false
-                            auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
-                            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
-                            // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                            if (smem_size >= 48 * 1024) {
-                                C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                            }
-                            kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                            C10_CUDA_KERNEL_LAUNCH_CHECK();
+                            BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                                // If Is_local, set Is_causal to false
+                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
+                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                                if (smem_size >= 48 * 1024) {
+                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                                }
+                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                            });
                         });
                     });
                 });

+ 52 - 13
flash_attn/flash_attn_interface.py

@@ -43,7 +43,7 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
         return (128, 64) if is_sm80 else (64, 64)
 
 
-def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
+def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -56,6 +56,7 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size,
         causal,
         window_size[0],
         window_size[1],
+        alibi_slopes,
         return_softmax,
         None,
     )
@@ -74,6 +75,7 @@ def _flash_attn_varlen_forward(
     softmax_scale,
     causal,
     window_size,
+    alibi_slopes,
     return_softmax,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
@@ -94,6 +96,7 @@ def _flash_attn_varlen_forward(
         causal,
         window_size[0],
         window_size[1],
+        alibi_slopes,
         return_softmax,
         None,
     )
@@ -116,6 +119,7 @@ def _flash_attn_backward(
     softmax_scale,
     causal,
     window_size,
+    alibi_slopes,
     rng_state=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
@@ -136,6 +140,7 @@ def _flash_attn_backward(
         causal,
         window_size[0],
         window_size[1],
+        alibi_slopes,
         None,
         rng_state,
     )
@@ -160,6 +165,7 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     window_size,
+    alibi_slopes,
     rng_state=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
@@ -185,6 +191,7 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
+        alibi_slopes,
         None,
         rng_state,
     )
@@ -195,7 +202,7 @@ def _flash_attn_varlen_backward(
 
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax):
+    def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -206,6 +213,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
@@ -213,6 +221,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.alibi_slopes = alibi_slopes
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -234,10 +243,11 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.alibi_slopes,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None
 
 
 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@@ -251,6 +261,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        alibi_slopes,
         return_softmax,
     ):
         if softmax_scale is None:
@@ -267,6 +278,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
@@ -275,6 +287,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.alibi_slopes = alibi_slopes
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -300,15 +313,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.alibi_slopes,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None
 
 
 class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax):
+    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -319,6 +333,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
@@ -326,6 +341,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.alibi_slopes = alibi_slopes
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -348,11 +364,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.alibi_slopes,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None
 
 
 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@@ -369,6 +386,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        alibi_slopes,
         return_softmax,
     ):
         if softmax_scale is None:
@@ -385,6 +403,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(
@@ -396,6 +415,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.alibi_slopes = alibi_slopes
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -422,16 +442,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.alibi_slopes,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
+    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
@@ -442,6 +463,7 @@ class FlashAttnFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
@@ -449,6 +471,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.alibi_slopes = alibi_slopes
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -469,12 +492,13 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.alibi_slopes,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None
 
 
 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -492,6 +516,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        alibi_slopes,
         return_softmax,
     ):
         if softmax_scale is None:
@@ -508,6 +533,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(
@@ -519,6 +545,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.alibi_slopes = alibi_slopes
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -543,12 +570,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.alibi_slopes,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
@@ -557,6 +585,7 @@ def flash_attn_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    alibi_slopes=None,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -589,7 +618,7 @@ def flash_attn_qkvpacked_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedFunc.apply(
-        qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
+        qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
     )
 
 
@@ -600,6 +629,7 @@ def flash_attn_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    alibi_slopes=None, 
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -648,7 +678,7 @@ def flash_attn_kvpacked_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnKVPackedFunc.apply(
-        q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
+        q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
     )
 
 
@@ -660,6 +690,7 @@ def flash_attn_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    alibi_slopes=None,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -706,7 +737,7 @@ def flash_attn_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnFunc.apply(
-        q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs
+        q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
     )
 
 
@@ -718,6 +749,7 @@ def flash_attn_varlen_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    alibi_slopes=None,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -760,6 +792,7 @@ def flash_attn_varlen_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        alibi_slopes,
         return_attn_probs,
     )
 
@@ -775,6 +808,7 @@ def flash_attn_varlen_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    alibi_slopes=None,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -839,6 +873,7 @@ def flash_attn_varlen_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        alibi_slopes,
         return_attn_probs,
     )
 
@@ -855,6 +890,7 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    alibi_slopes=None,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -918,6 +954,7 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         window_size,
+        alibi_slopes,
         return_attn_probs,
     )
 
@@ -937,6 +974,7 @@ def flash_attn_with_kvcache(
     window_size=(-1, -1),  # -1 means infinite context window
     rotary_interleaved=True,
     num_splits=0,
+    alibi_slopes=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1041,5 +1079,6 @@ def flash_attn_with_kvcache(
         window_size[1],
         rotary_interleaved,
         num_splits,
+        alibi_slopes
     )
     return out

+ 1006 - 0
tests/test_alibi.py

@@ -0,0 +1,1006 @@
+import math
+
+import pytest
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from flash_attn import (flash_attn_func, flash_attn_kvpacked_func,
+                        flash_attn_qkvpacked_func, flash_attn_varlen_func,
+                        flash_attn_varlen_kvpacked_func,
+                        flash_attn_varlen_qkvpacked_func,
+                        flash_attn_with_kvcache)
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import _get_block_size
+from flash_attn.flash_attn_triton import \
+    flash_attn_func as flash_attn_func_triton
+from flash_attn.layers.rotary import apply_rotary_emb
+
+MAX_HEADDIM_SM8x = 192
+
+
+is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
+is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
+is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
+is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
+
+
+def generate_alibi(max_seq_len, num_attention_heads, tp_world_size, tp_index, key_padding_mask=None, device="cuda"):
+    def get_slopes(n):
+        def get_slopes_power_of_2(n):
+            start = (2 ** (-2 ** -(math.log2(n) - 3)))
+            ratio = start
+            return [start * ratio ** i for i in range(n)]
+
+        if math.log2(n).is_integer():
+            return get_slopes_power_of_2(n)
+        else:
+            closest_power_of_2 = 2 ** math.floor(math.log2(n))
+            return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
+                                                                :n - closest_power_of_2]
+
+    slopes = torch.tensor(get_slopes(num_attention_heads)).to(device=device)
+    # Select the part of the tensor that corresponds to our tensor parallel index.
+    assert (num_attention_heads/tp_world_size).is_integer(
+    ), "it works only when (num_attention_heads/tp_world_size) is integer"
+    nh_tp = num_attention_heads // tp_world_size
+    slopes = slopes[nh_tp * tp_index:nh_tp * (tp_index + 1)]
+
+    if (key_padding_mask is None):
+        arange_tensor = rearrange(torch.arange(max_seq_len), "sqk -> 1 sqk").to(device=device)
+    else:
+        arange_tensor = (key_padding_mask.cumsum(dim=-1, dtype=slopes.dtype) - 1) \
+            .masked_fill_(~key_padding_mask, torch.finfo(torch.float).min).to(device=device)
+        
+    arange_tensor = rearrange(arange_tensor, 'b sqk -> b 1 1 sqk')
+    # (1, nheads, 1, seqlen_k) or (batch, nheads, 1, seqlen_k)
+    alibi_tensor = rearrange(slopes, 'nh -> 1 nh 1 1') * arange_tensor
+
+    return alibi_tensor, slopes
+    
+
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", right_padding=True):
+    assert mode in ["full", "random", "third"]
+    if mode == "full":
+        lengths = torch.full((batch_size, 1), max_seqlen,
+                             device=device, dtype=torch.int32)
+    elif mode == "random":
+        lengths = torch.randint(
+            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+        )
+    elif mode == "third":
+        lengths = torch.randint(
+            max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+    if right_padding:
+        padding_mask = (
+            repeat(torch.arange(max_seqlen, device=device),
+                "s -> b s", b=batch_size) < lengths
+        )
+    else:
+        padding_mask = (
+            repeat(torch.arange(start=max_seqlen-1, end=-1, step=-1, device=device),
+                "s -> b s", b=batch_size) < lengths
+        )
+    return padding_mask
+
+
+def generate_qkv(
+    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, d)
+        k: (batch_size, seqlen_k, nheads_k, d)
+        v: (batch_size, seqlen_k, nheads_k, d)
+        query_padding_mask: (batch_size, seqlen), bool
+        key_padding_mask: (batch_size, seqlen), bool
+    """
+    assert not (kvpacked and qkvpacked)
+    batch_size, seqlen_q, nheads, d = q.shape
+    _, seqlen_k, nheads_k, _ = k.shape
+    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+
+    if query_padding_mask is not None:
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
+            q, query_padding_mask)
+
+        def output_pad_fn(output_unpad): return pad_input(
+            output_unpad, indices_q, batch_size, seqlen_q
+        )
+    else:
+        q_unpad = rearrange(q, "b s h d -> (b s) h d")
+        cu_seqlens_q = torch.arange(
+            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
+        )
+        max_seqlen_q = seqlen_q
+
+        def output_pad_fn(output_unpad): return rearrange(
+            output_unpad, "(b s) h d -> b s h d", b=batch_size
+        )
+
+    if key_padding_mask is not None:
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(
+            k, key_padding_mask)
+        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+    else:
+        k_unpad = rearrange(k, "b s h d -> (b s) h d")
+        v_unpad = rearrange(v, "b s h d -> (b s) h d")
+        cu_seqlens_k = torch.arange(
+            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
+        )
+        max_seqlen_k = seqlen_k
+
+    if qkvpacked:
+        assert (query_padding_mask == key_padding_mask).all()
+        assert nheads == nheads_k
+        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+        qkv = torch.stack([q, k, v], dim=2)
+        if query_padding_mask is not None:
+            def dqkv_pad_fn(dqkv_unpad): return pad_input(
+                dqkv_unpad, indices_q, batch_size, seqlen_q)
+        else:
+            def dqkv_pad_fn(dqkv_unpad): return rearrange(
+                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            qkv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            max_seqlen_q,
+            qkv.detach().requires_grad_(),
+            output_pad_fn,
+            dqkv_pad_fn,
+        )
+    elif kvpacked:
+        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+        kv = torch.stack([k, v], dim=2)
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            def dkv_pad_fn(dkv_unpad): return pad_input(
+                dkv_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            def dkv_pad_fn(dkv_unpad): return rearrange(
+                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            q_unpad.detach().requires_grad_(),
+            kv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            kv.detach().requires_grad_(),
+            output_pad_fn,
+            dq_pad_fn,
+            dkv_pad_fn,
+        )
+    else:
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            def dk_pad_fn(dk_unpad): return pad_input(
+                dk_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            def dk_pad_fn(dk_unpad): return rearrange(
+                dk_unpad, "(b s) h d -> b s h d", b=batch_size)
+        return (
+            q_unpad.detach().requires_grad_(),
+            k_unpad.detach().requires_grad_(),
+            v_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            k.detach().requires_grad_(),
+            v.detach().requires_grad_(),
+            output_pad_fn,
+            dq_pad_fn,
+            dk_pad_fn,
+        )
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    query_padding_mask=None,
+    key_padding_mask=None,
+    device=None,
+):
+    row_idx = rearrange(torch.arange(
+        seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(
+            col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            col_idx < row_idx + sk - sq - window_size[0],
+        )
+
+
+def attention_ref(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+    upcast=True,
+    reorder_ops=False,
+    bias=None
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k: (batch_size, seqlen_k, nheads_k, head_dim)
+        v: (batch_size, seqlen_k, nheads_k, head_dim)
+        query_padding_mask: (batch_size, seqlen_q)
+        key_padding_mask: (batch_size, seqlen_k)
+        dropout_p: float
+        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        causal: whether to apply causal masking
+        window_size: (int, int), left and right window size
+        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
+            output back to fp16/bf16.
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
+            without changing the math. This is to estimate the numerical error from operation
+            reordering.
+    Output:
+        output: (batch_size, seqlen_q, nheads, head_dim)
+        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    dtype_og = q.dtype
+    if upcast:
+        q, k, v = q.float(), k.float(), v.float()
+    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    d = q.shape[-1]
+    if not reorder_ops:
+        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+    else:
+        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+    if bias is not None:
+        bias = bias.to(scores.dtype)
+        scores += bias
+    if key_padding_mask is not None:
+        scores.masked_fill_(rearrange(~key_padding_mask,
+                            "b s -> b 1 1 s"), float("-inf"))
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            q.device,
+        )
+        scores.masked_fill_(local_mask, float("-inf"))
+    attention = torch.softmax(scores, dim=-1)
+    # Some rows might be completely masked out so we fill them with zero instead of NaN
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        attention = attention.masked_fill(
+            torch.all(local_mask, dim=-1, keepdim=True), 0.0)
+    # We want to mask here so that the attention matrix doesn't have any NaNs
+    # Otherwise we'll get NaN in dV
+    if query_padding_mask is not None:
+        attention = attention.masked_fill(
+            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    dropout_scaling = 1.0 / (1 - dropout_p)
+    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
+    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
+    if dropout_mask is not None:
+        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+    else:
+        attention_drop = attention
+    output = torch.einsum(
+        "bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+    if query_padding_mask is not None:
+        output.masked_fill_(
+            rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
+
+
+def attention_kvpacked_ref(
+    q,
+    kv,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+    upcast=True,
+    reorder_ops=False,
+):
+    return attention_ref(
+        q,
+        kv[:, :, 0],
+        kv[:, :, 1],
+        query_padding_mask,
+        key_padding_mask,
+        dropout_p,
+        dropout_mask,
+        upcast=upcast,
+        causal=causal,
+        window_size=window_size,
+        reorder_ops=reorder_ops,
+    )
+
+
+def attention_qkvpacked_ref(
+    qkv,
+    key_padding_mask=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+    upcast=True,
+    reorder_ops=False,
+):
+    return attention_ref(
+        qkv[:, :, 0],
+        qkv[:, :, 1],
+        qkv[:, :, 2],
+        key_padding_mask,
+        key_padding_mask,
+        dropout_p,
+        dropout_mask,
+        upcast=upcast,
+        causal=causal,
+        window_size=window_size,
+        reorder_ops=reorder_ops,
+    )
+
+
+def generate_sparsity_mask(seqlen, sparsity=0.3):
+    repeats = seqlen // 16 // 2
+    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
+    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
+    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
+    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
+    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
+    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
+    nrow, ncol = seqlen // 16, seqlen // 256
+    mask = torch.rand(nrow, ncol, device="cuda") < sparsity
+    return mask
+
+
+def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, head_dim)
+        blockmask: (seqlen / 16, seqlen / 256)
+        attn_mask: (batch_size, seqlen)
+        dropout_p: float
+        dropout_mask: (batch_size, nheads, seqlen, seqlen)
+    Output:
+        output: (batch_size, seqlen, nheads, head_dim)
+        attention: softmax after dropout
+    """
+    q, k, v = qkv.float().unbind(dim=2)
+    d = qkv.shape[-1]
+    seqlen = qkv.shape[1]
+    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
+    blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
+    blockmask = blockmask[:seqlen, :seqlen]
+    scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
+    attention = torch.softmax(scores, dim=-1)
+    attention = attention.masked_fill(
+        rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
+    attention = attention.masked_fill_(
+        rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
+    attention_drop = attention.masked_fill(
+        ~dropout_mask, 0.0) / (1 - dropout_p)
+    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
+    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
+    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
+
+
+def convert_flash_attn_S_to_softmax(
+    S,
+    seqlen_q,
+    seqlen_k,
+    query_padding_mask,
+    key_padding_mask,
+    head_dim,
+    is_dropout,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+):
+    """FlashAttention stores the S matrix in a different way.
+    Arguments:
+        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
+        query_padding_mask: (batch_size, seqlen_q_rounded)
+        key_padding_mask: (batch_size, seqlen_k_rounded)
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
+    warps_n = 4
+    blocksize_m, blocksize_n = _get_block_size(
+        S.device, head_dim, is_dropout, causal)
+    nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
+    nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
+    mmas_n = (blocksize_n + 16 - 1) // 16
+    S_flat = rearrange(
+        S,
+        "b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
+        blocksize_m=blocksize_m,
+        blocksize_n=blocksize_n,
+    )
+    S_converted = rearrange(
+        S_flat,
+        "b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
+        mmas_n=mmas_n,
+        warps_n=warps_n,
+        eight=8,
+        c0=2,
+        c1=2,
+        c2=2,
+        four=4,
+    )
+
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            S.device,
+        )
+        local_mask = F.pad(
+            local_mask,
+            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
+            value=True,
+        )
+        S_converted.masked_fill_(local_mask, 0.0)
+
+    # Need to zero out things not in attention_mask in case S was initialized with random values
+    # and some of those values aren't overwritten.
+    seqlen_q_og = (
+        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
+    )
+    if query_padding_mask is not None:
+        query_padding_mask = F.pad(
+            query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
+        S_converted = S_converted.masked_fill(
+            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
+    if key_padding_mask is not None:
+        key_padding_mask = F.pad(
+            key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
+        S_converted = S_converted.masked_fill(
+            rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
+    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
+    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
+    return S_converted[:, :, :seqlen_q, :seqlen_k]
+
+
+def normalize_flash_attn_S(
+    attn_unnorm,
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    is_dropout=False,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k, v: (batch_size, seqlen_k, nheads, head_dim)
+        key_padding_mask: (batch_size, seqlen_q)
+    Output:
+        softmax_lse: (batch_size, nheads, seqlen_q)
+        softmax_max: (batch_size, nheads, seqlen_q)
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    q, k, v = q.float(), k.float(), v.float()
+    _, seqlen_q, _, head_dim = q.shape
+    seqlen_k = k.shape[1]
+    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
+    if key_padding_mask is not None:
+        scores.masked_fill_(rearrange(~key_padding_mask,
+                            "b s -> b 1 1 s"), float("-inf"))
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            q.device,
+        )
+        scores.masked_fill_(local_mask, float("-inf"))
+    _, block_size_n = _get_block_size(
+        scores.device, head_dim, is_dropout, causal)
+    scores_block = scores.split(block_size_n, dim=-1)
+    lse_block = torch.stack([torch.logsumexp(s, dim=-1)
+                            for s in scores_block], dim=-1)
+    lse = torch.logsumexp(lse_block, dim=-1)
+    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
+    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
+    lse[lse == float("-inf")] = float("inf")
+    scores_max_block = torch.stack(
+        [torch.amax(s, dim=-1) for s in scores_block], dim=-1)
+    cummax_block = torch.cummax(
+        scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
+    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
+    attn_norm = torch.cat(
+        [
+            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
+            for a, m in zip(attn_unnorm_block, cummax_block)
+        ],
+        dim=-1,
+    )
+    if query_padding_mask is not None:
+        attn_norm.masked_fill_(
+            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    return attn_norm.to(dtype=attn_unnorm.dtype)
+
+
+def get_dropout_fraction(
+    dropout_mask,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+):
+    """
+    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
+    query_padding_mask: (batch_size, seqlen_q)
+    key_padding_mask: (batch_size, seqlen_k)
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
+    dropped = ~dropout_mask
+    valid = torch.ones_like(dropout_mask)
+    if query_padding_mask is not None:
+        dropped.masked_fill_(
+            rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
+        valid.masked_fill_(
+            rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
+    if key_padding_mask is not None:
+        dropped.masked_fill_(
+            rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
+        valid.masked_fill_(
+            rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            dropout_mask.device,
+        )
+        dropped.masked_fill_(local_mask, False)
+        valid.masked_fill_(local_mask, False)
+    dropped_total = dropped.sum()
+    return dropped.sum() / valid.sum()
+
+
+@pytest.mark.parametrize(
+    "dtype", [torch.float16]
+)
+@pytest.mark.parametrize(
+    "b_sq", 
+    [
+        (32, 512), 
+        (16, 1024), 
+        (8, 2048),
+        (4, 4096), 
+        (2, 8192), 
+        (1, 16384)
+    ]
+)
+@pytest.mark.parametrize(
+    "nh_hd", 
+    [
+        (32, 64),
+        (16, 128),
+        (40, 128) # non power of 2 nh
+    ]
+)
+@pytest.mark.parametrize(
+    "tp_world_size", [1, 2, 4]
+)
+def test_flash_attn_func(b_sq, nh_hd, tp_world_size, dtype):
+    b, sq = b_sq
+    nh, hd = nh_hd
+    nh_tp = nh // tp_world_size
+    q, k, v = [torch.randn(b, sq, nh_tp, hd, device="cuda", 
+                           dtype=dtype, requires_grad=True) for _ in range(3)]
+    dout = torch.rand_like(q)
+
+    for tp_index in range(tp_world_size):
+        alibi, alibi_slopes = generate_alibi(
+            max_seq_len=sq,
+            num_attention_heads=nh,
+            tp_world_size=tp_world_size,
+            tp_index=tp_index,
+            key_padding_mask=None,
+            device="cuda"
+        )
+
+        triton_out = flash_attn_func_triton(
+            q, k, v, alibi, True, hd**(-0.5))
+        triton_out.backward(dout)
+        triton_dq, q.grad = q.grad.clone(), None
+        triton_dk, k.grad = k.grad.clone(), None
+        triton_dv, v.grad = v.grad.clone(), None
+
+        flash_out = flash_attn_func(q, k, v, causal=True, alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b))
+        flash_out.backward(dout)
+        flash_dq, q.grad = q.grad.clone(), None
+        flash_dk, k.grad = k.grad.clone(), None
+        flash_dv, v.grad = v.grad.clone(), None
+
+        assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.)
+        assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.)
+        assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.)
+        assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.)
+
+
+@pytest.mark.parametrize(
+    "dtype", [torch.float16]
+)
+@pytest.mark.parametrize(
+    "right_padding", [True, False]
+)
+@pytest.mark.parametrize(
+    "b_sq", 
+    [
+        (32, 512), 
+        (16, 1024), 
+        (8, 2048),
+        (4, 4096), 
+        (2, 8192), 
+        (1, 16384)
+    ]
+)
+@pytest.mark.parametrize(
+    "nh_hd", 
+    [
+        (32, 64),
+        (16, 128),
+        (40, 128) # non power of 2 nh
+    ]
+)
+@pytest.mark.parametrize(
+    "tp_world_size", [1, 2, 4]
+)
+def test_flash_attn_varlen_func(b_sq, nh_hd, tp_world_size, right_padding, dtype):
+    b, sqk = b_sq
+    nh, hd = nh_hd
+    nh_tp = nh // tp_world_size
+    # flash_attn_func_triton(), flash-attention v2 (above v2.1) causal logic are different
+    # so only (seqlen_q == 1, causal=False to triton ver.) shows correct results
+    # https://github.com/huggingface/text-generation-inference/blob/v1.1.1/server/text_generation_server/models/custom_modeling/mpt_modeling.py#L53-L63
+    q = torch.randn(b, 1, nh_tp, hd, device="cuda", dtype=dtype, requires_grad=True)
+    k, v = [torch.randn(b, sqk, nh_tp, hd, device="cuda",
+                        dtype=dtype, requires_grad=True) for _ in range(2)]
+    dout = torch.rand_like(q)
+
+    padding_mask = generate_random_padding_mask(sqk, b, "cuda", "random", right_padding)
+    (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+        dq_pad_fn,
+        dk_pad_fn,
+    ) = generate_qkv(q, k, v, None, padding_mask, kvpacked=False)
+
+    for tp_index in range(tp_world_size):
+        alibi, alibi_slopes = generate_alibi(
+            max_seq_len=sqk,
+            num_attention_heads=nh,
+            tp_world_size=tp_world_size,
+            tp_index=tp_index,
+            key_padding_mask=padding_mask,
+            device="cuda"
+        )
+
+        triton_out = flash_attn_func_triton(
+            q, k, v, alibi, False, hd**(-0.5))
+        triton_out.backward(dout)
+        triton_dq, q.grad = q.grad.clone(), None
+        triton_dk, k.grad = k.grad.clone(), None
+        triton_dv, v.grad = v.grad.clone(), None
+
+        flash_out_unpad = flash_attn_varlen_func(
+            q_unpad,
+            k_unpad,
+            v_unpad,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            causal=True,
+            alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b)
+        )
+        flash_out = output_pad_fn(flash_out_unpad)
+        flash_out.backward(dout)
+        flash_dq_unpad, q_unpad.grad = q_unpad.grad.clone(), None
+        flash_dk_unpad, k_unpad.grad = k_unpad.grad.clone(), None
+        flash_dv_unpad, v_unpad.grad = v_unpad.grad.clone(), None
+        flash_dq = dq_pad_fn(flash_dq_unpad)
+        flash_dk = dk_pad_fn(flash_dk_unpad)
+        flash_dv = dk_pad_fn(flash_dv_unpad)
+
+        assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.)
+        assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.)
+        assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.)
+        assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.)
+
+
+@pytest.mark.parametrize("alibi", [True])
+@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("num_splits", [1, 0])
+# @pytest.mark.parametrize("num_splits", [0])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("new_kv", [False, True])
+# @pytest.mark.parametrize("new_kv", [True])
+# @pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("local", [False])
+# @pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
+# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
+@pytest.mark.parametrize("rotary_interleaved", [False, True])
+# @pytest.mark.parametrize("rotary_interleaved", [False])
+@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
+# @pytest.mark.parametrize("rotary_fraction", [0.0])
+@pytest.mark.parametrize("has_batch_idx", [False, True])
+# @pytest.mark.parametrize("has_batch_idx", [True])
+@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [128])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 128),
+        (1, 339),
+        (3, 1024),
+        (64, 800),
+        (64, 256),
+        (3, 799),
+        (64, 2048),
+        (16, 20000),
+        (1, 128 * 1024),
+        (16, 128 * 1024),
+        (128, 128),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+def test_flash_attn_kvcache(
+    seqlen_q,
+    seqlen_k,
+    d,
+    has_batch_idx,
+    rotary_fraction,
+    rotary_interleaved,
+    seqlen_new_eq_seqlen_q,
+    causal,
+    local,
+    new_kv,
+    mha_type,
+    num_splits,
+    dtype,
+    alibi,
+):
+    if seqlen_q > seqlen_k and new_kv:
+        pytest.skip()
+    if not new_kv and rotary_fraction > 0.0:
+        pytest.skip()
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
+    nheads = 8
+    # rotary_dim must be a multiple of 16, and must be <= d
+    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
+    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 4)
+    assert nheads % nheads_k == 0
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads,
+                    d, device=device, dtype=dtype)
+    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(
+        1, seqlen_q + 1, (1,)).item()
+    if new_kv:
+        k = torch.randn(batch_size, seqlen_new, nheads_k,
+                        d, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seqlen_new, nheads_k,
+                        d, device=device, dtype=dtype)
+    else:
+        k, v = None, None
+    k_cache = torch.randn(batch_size_cache, seqlen_k,
+                          nheads_k, d, device=device, dtype=dtype)
+    v_cache = torch.randn(batch_size_cache, seqlen_k,
+                          nheads_k, d, device=device, dtype=dtype)
+    cache_seqlens = torch.randint(
+        0,
+        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
+        (seqlen_k - (seqlen_q if (causal or local)
+         and rotary_dim > 1 else seqlen_new) + 1)
+        if new_kv
+        else (seqlen_k + 1),
+        (batch_size,),
+        dtype=torch.int32,
+        device=device,
+    )
+    if has_batch_idx:
+        cache_batch_idx = torch.randperm(
+            batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
+    else:
+        cache_batch_idx = None
+    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
+    if rotary_dim > 0:
+        angle = torch.rand(seqlen_k, rotary_dim // 2,
+                           device=device) * 2 * math.pi
+        cos = torch.cos(angle).to(dtype=dtype)
+        sin = torch.sin(angle).to(dtype=dtype)
+        if causal or local:
+            q_ro = apply_rotary_emb(
+                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+            )
+        else:
+            q_ro = rearrange(
+                apply_rotary_emb(
+                    rearrange(q, "b s h d -> b 1 (s h) d"),
+                    cos,
+                    sin,
+                    seqlen_offsets=cache_seqlens,
+                    interleaved=rotary_interleaved,
+                ),
+                "b 1 (s h) d -> b s h d",
+                s=seqlen_q,
+            )
+        # q_ro = q
+        k_ro = apply_rotary_emb(
+            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
+        )
+    else:
+        cos, sin = None, None
+        q_ro, k_ro = q, k
+    # k_cache[:, 64:] = -1
+    k_cache_ref = (
+        k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
+    v_cache_ref = (
+        v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
+    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
+    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
+    if new_kv:
+        update_mask = torch.logical_and(
+            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
+        )
+        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
+        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
+    k_cache_rep = repeat(
+        k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+    v_cache_rep = repeat(
+        v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
+    if alibi:
+        seqlen_alibi = k_cache_rep.shape[1]
+        alibi_tensor, alibi_slopes = generate_alibi(
+            max_seq_len=seqlen_alibi,
+            num_attention_heads=nheads,
+            tp_world_size=1,
+            tp_index=0,
+            key_padding_mask=None,
+            device="cuda"
+        )
+        # alibi_tensor = alibi_tensor.expand(batch_size, -1, seqlen_q, -1)
+        alibi_slopes = repeat(alibi_slopes, "nh -> b nh", b=batch_size)
+        if alibi_tensor.abs().max().item() >= torch.finfo(dtype).max:
+            pytest.skip()
+    else:
+        alibi_tensor, alibi_slopes = None, None
+    out = flash_attn_with_kvcache(
+        q,
+        k_cache,
+        v_cache,
+        k,
+        v,
+        cos,
+        sin,
+        cache_seqlens,
+        cache_batch_idx,
+        causal=causal,
+        window_size=window_size,
+        rotary_interleaved=rotary_interleaved,
+        num_splits=num_splits,
+        alibi_slopes=alibi_slopes
+    )
+    # out = flash_attn_with_kvcache(
+    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
+    # )
+    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
+    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
+    # m = qk.amax(-1, keepdim=True)
+    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
+    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
+    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
+    # probs = torch.softmax(qk, dim=-1)
+    key_padding_mask = arange < cache_seqlens_expanded + \
+        (seqlen_new if new_kv else 0)
+    out_ref, _ = attention_ref(
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
+        bias=alibi_tensor
+    )
+    out_pt, _ = attention_ref(
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=causal,
+        window_size=window_size,
+        upcast=False,
+        reorder_ops=True,
+        bias=alibi_tensor
+    )
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    if new_kv:
+        k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
+        v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
+        assert torch.allclose(k_cache_select, k_cache_ref,
+                              rtol=1e-3, atol=1e-3)
+        assert torch.equal(v_cache_select, v_cache_ref)
+    assert (out - out_ref).abs().max().item() <= 3 * \
+        (out_pt - out_ref).abs().max().item() + 1e-5