Ver Fonte

Support unpadded LSE layout (#970)

* Support unpadded LSE layout.

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>

* Cleanup

* Fix unpadded LSE on split-kv path

* Fix formatting and comments

* Fix inline vs forceinline

---------

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
Grigory Sizov há 9 meses atrás
pai
commit
f816dee63c

+ 23 - 11
csrc/flash_attn/flash_api.cpp

@@ -43,7 +43,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
-                      bool seqlenq_ngroups_swapped=false) {
+                      bool seqlenq_ngroups_swapped=false,
+                      const bool unpadded_lse=false) {
 
     // Reset the parameters
     params = {};
@@ -135,6 +136,9 @@ void set_params_fprop(Flash_fwd_params &params,
     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
         TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
     #endif
+
+    params.unpadded_lse = unpadded_lse;
+    params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
 }
 
 void set_params_dgrad(Flash_bwd_params &params,
@@ -168,7 +172,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
-                      bool deterministic) {
+                      bool deterministic,
+                      const bool unpadded_lse) {
 
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -181,7 +186,9 @@ void set_params_dgrad(Flash_bwd_params &params,
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     false, // seqlenq_ngroups_swapped
+                     unpadded_lse);
 
     // Set the pointers and strides.
     params.do_ptr = dout.data_ptr();
@@ -651,8 +658,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
-
-    auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
     at::Tensor p;
     // Only return softmax if there's dropout to reduce compilation time
     if (return_softmax) {
@@ -683,7 +689,9 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     seqlenq_ngroups_swapped);
+                     seqlenq_ngroups_swapped,
+                     /*unpadded_lse*/true);
+    params.total_q = total_q;
 
     if (paged_KV) {
         params.block_table = block_table.data_ptr<int>();
@@ -739,7 +747,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
         out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
         out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
         q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
-        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
+        softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
     }
 
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
@@ -933,7 +941,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     deterministic);
+                     deterministic,
+                     /*unpadded_lse*/false);
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
@@ -986,7 +995,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                const at::Tensor &k,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &v,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &out,   // total_q x num_heads x head_size
-               const at::Tensor &softmax_lse,     // b x h x s   softmax logsumexp
+               const at::Tensor &softmax_lse,    // h x total_q, softmax logsumexp
                c10::optional<at::Tensor> &dq_,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                c10::optional<at::Tensor> &dk_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
                c10::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -1126,7 +1135,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
-    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
+    auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
     at::Tensor dq_accum;
     if (loop) {
         // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
@@ -1137,6 +1146,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
         // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
         // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
+        // Same holds for softmax_d, since LSE is stored in unpadded format.
         if (!deterministic) {
             dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
         } else {
@@ -1182,8 +1192,10 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                      softmax_scale,
                      window_size_left,
                      window_size_right,
-                     deterministic);
+                     deterministic,
+                     /*unpadded_lse*/true);
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
+    params.total_q = total_q;
 
     auto launch = &run_mha_bwd;
 

+ 4 - 1
csrc/flash_attn/src/flash.h

@@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@@ -138,6 +138,9 @@ struct Flash_fwd_params : public Qkv_params {
 
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
+
+    bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+ 3 - 4
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -120,10 +120,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
         // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
         + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
-    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
-        + (m_block_max - 1) * kBlockM;
-    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
-        + (m_block_max - 1) * kBlockM;
+    const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
+    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
+    const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
 
     Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},

+ 2 - 1
csrc/flash_attn/src/flash_bwd_preprocess_kernel.h

@@ -79,7 +79,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
         + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
-    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
+    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
+    const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;
 
     Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                              Shape<Int<kBlockM>, Int<kHeadDim>>{},

+ 55 - 13
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -24,6 +24,27 @@ using namespace cute;
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+        // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+        // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+        // Otherwise, it's written as (h, b, seqlen_q).
+        const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+        auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+        auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+        auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+        auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+            params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) :  make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+            );
+
+        auto lse_layout = make_layout(lse_shape, lse_stride);
+        Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+        auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+        return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
+
+
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
@@ -74,10 +95,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                                 make_stride(params.o_row_stride, params.o_head_stride, _1{}));
         Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                               make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-        Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                                  make_shape(params.b, params.h, params.seqlen_q),
-                                  make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-        Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+
+        Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
 
         typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
         auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -424,10 +443,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                             make_stride(params.o_row_stride, params.o_head_stride, _1{}));
     Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_coord(m_block, 0));  // (kBlockM, kHeadDim)
-    Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
-                              make_shape(params.b, params.h, params.seqlen_q),
-                              make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{}));
-    Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape<Int<kBlockM>>{}, make_coord(m_block));
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
 
     typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
     auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -986,7 +1002,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
                                          + m_block * kBlockM) * params.d_rounded;
-    const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
+            ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
+        ) + m_block * kBlockM;
 
     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -1092,21 +1110,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     const int tidx = threadIdx.x;
     const int bidx = blockIdx.x;
 
+    const index_t lse_size = params.b * params.h * params.seqlen_q;
+
     const index_t row_offset_lse = bidx * kBlockM;
     Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
                                    Shape<Int<kMaxSplits>, Int<kBlockM>>{},
-                                   make_stride(params.b * params.h * params.seqlen_q, _1{}));
+                                   make_stride(lse_size, _1{}));
+
+    // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile.
+    // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
+    Layout flat_layout = make_layout(lse_size);
+    Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+    auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+    Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+    Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+    Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
+
     constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
 
-    // Read the LSE values from gmem and store them in shared memory, then tranpose them.
+    // Read the LSE values from gmem and store them in shared memory, then transpose them.
     constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
-        ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+        ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
         // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }
@@ -1145,7 +1178,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
     ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
     // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
-    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
+    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+        if (params.unpadded_lse) {
+            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+            if (lse_offset < lse_size) {
+                gLSE_unpadded(lse_offset) = lse_logsum;
+            }
+        } else {
+            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+        }
+    }
     // Store the scales exp(lse - lse_logsum) in shared memory.
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {

+ 15 - 5
flash_attn/flash_attn_interface.py

@@ -130,7 +130,12 @@ def _flash_attn_backward(
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.bwd(
         dout,
         q,
         k,
@@ -178,7 +183,12 @@ def _flash_attn_varlen_backward(
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
-    dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = flash_attn_cuda.varlen_bwd(
         dout,
         q,
         k,
@@ -883,7 +893,7 @@ def flash_attn_varlen_qkvpacked_func(
            (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -968,7 +978,7 @@ def flash_attn_varlen_kvpacked_func(
            (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
@@ -1056,7 +1066,7 @@ def flash_attn_varlen_func(
            (they might not have the right scaling).
     Return:
         out: (total, nheads, headdim).
-        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).

+ 6 - 3
tests/test_flash_attn.py

@@ -1151,6 +1151,7 @@ def test_flash_attn_varlen_output(
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@@ -1918,9 +1919,11 @@ def test_flash_attn_kvcache(
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
-        if new_kv
-        else (seqlen_k + 1),
+        (
+            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
+            if new_kv
+            else (seqlen_k + 1)
+        ),
         (batch_size,),
         dtype=torch.int32,
         device=device,