1
0
Эх сурвалжийг харах

Add seqused_q in fwd / bwd and seqused_k in bwd.

Ying Zhang 6 сар өмнө
parent
commit
496fdc4f6c

+ 14 - 6
flash_attn/bert_padding.py

@@ -95,19 +95,23 @@ class IndexFirstAxisResidual(torch.autograd.Function):
 index_first_axis_residual = IndexFirstAxisResidual.apply
 
 
-def unpad_input(hidden_states, attention_mask):
+def unpad_input(hidden_states, attention_mask, unused_mask=None):
     """
     Arguments:
         hidden_states: (batch, seqlen, ...)
         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
     Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (used_nnz), the indices of non-masked tokens from the flattened input sequence.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
+        seqused: (batch), optionally returns the number of tokens selected in attention_mask + unused_mask if unused_mask is not None.
     """
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
@@ -115,12 +119,16 @@ def unpad_input(hidden_states, attention_mask):
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
     # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
     # so we write custom forward and backward to make it a bit faster.
-    return (
+    res = (
         index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
         indices,
         cu_seqlens,
         max_seqlen_in_batch,
     )
+    if unused_mask is not None:
+        return res + (used_seqlens_in_batch, )
+    else:
+        return res
 
 
 def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):

+ 7 - 3
hopper/epilogue_bwd_sm90_tma.hpp

@@ -80,6 +80,7 @@ struct CollectiveEpilogueBwd {
         Element* ptr_dV;
         StridedKV const stride_dV;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Device side kernel params
@@ -91,6 +92,7 @@ struct CollectiveEpilogueBwd {
         StridedKV const stride_dV;
         TMA_dKV tma_store_dK, tma_store_dV;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     static Params
@@ -113,7 +115,7 @@ struct CollectiveEpilogueBwd {
             select<1, 2>(TileShape_MNK{}),
             _1{}); // no mcast for dKV
         return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
-                tma_store_dK, tma_store_dV, args.cu_seqlens};
+                tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -185,7 +187,9 @@ struct CollectiveEpilogueBwd {
             cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
             bool const is_varlen = params.cu_seqlens != nullptr;
             int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
-            int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb];
+            int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (
+                params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]
+            );
 
             Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
             Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
@@ -236,7 +240,7 @@ struct CollectiveEpilogueBwd {
         auto [n_block, bidh, bidb] = block_coord;
         bool const is_varlen = Varlen && params.cu_seqlens != nullptr;
         int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
-        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - offset;
+        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset);
 
         Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
         Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)

+ 3 - 1
hopper/flash.h

@@ -68,7 +68,9 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
 
-    // If provided, the actual length of each k sequence.
+    // If provided, the actual length of each q / o sequence.
+    int * __restrict__ seqused_q;
+    // If provided, the actual length of each k / v sequence.
     int * __restrict__ seqused_k;
 
     int *__restrict__ blockmask;

+ 40 - 3
hopper/flash_api.cpp

@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_q,
                       void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
@@ -80,6 +81,7 @@ void set_params_fprop(Flash_fwd_params &params,
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_q = static_cast<int *>(seqused_q);
     params.seqused_k = static_cast<int *>(seqused_k);
 
     TORCH_CHECK(
@@ -171,6 +173,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       at::Tensor dv,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_q,
+                      void *seqused_k,
                       void *dq_accum_d,
                       void *dk_accum_d,
                       void *dv_accum_d,
@@ -187,7 +191,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                      q, k, v, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
-                     nullptr,
+                     seqused_q,
+                     seqused_k,
                      nullptr,
                      softmax_lse_d,
                      p_dropout,
@@ -364,6 +369,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_q=*/nullptr,
                      /*seqused_k=*/nullptr,
                      nullptr,
                      softmax_lse.data_ptr(),
@@ -426,6 +432,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                int max_seqlen_q,
                const int max_seqlen_k,
@@ -482,6 +489,14 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
 
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    if (seqused_q.has_value()){
+        auto seqused_q_ = seqused_q.value();
+        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
+        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
+        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
+        CHECK_SHAPE(seqused_q_, batch_size);
+    }
+
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
@@ -537,6 +552,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k.data_ptr(),
+                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      /*p_d=*/nullptr,
                      softmax_lse.data_ptr(),
@@ -730,8 +746,10 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
                      head_size, head_size_rounded,
                      q, k, v, out,
                      dout_padded, dq, dk_expanded, dv_expanded,
-                     nullptr,
-                     nullptr,
+                     /*cu_seqlens_q_d=*/nullptr,
+                     /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_q=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      dq_accum.data_ptr(),
                      // loop ? dk_accum.data_ptr() : nullptr,
                      // loop ? dv_accum.data_ptr() : nullptr,
@@ -787,6 +805,8 @@ mha_varlen_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x
                c10::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                const int max_seqlen_q,
                const int max_seqlen_k,          // max sequence length to choose the kernel
                const float softmax_scale,
@@ -854,7 +874,22 @@ mha_varlen_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x
     CHECK_SHAPE(out, total_q, num_heads, head_size);
     CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    if (seqused_q.has_value()){
+        auto seqused_q_ = seqused_q.value();
+        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
+        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
+        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
+        CHECK_SHAPE(seqused_q_, batch_size);
+    }
+
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }
 
     at::Tensor dq, dk, dv;
     if (dq_.has_value()) {
@@ -927,6 +962,8 @@ mha_varlen_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x
                      dout_padded, dq, dk_expanded, dv_expanded,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      dq_accum.data_ptr(),
                      // loop ? dk_accum.data_ptr() : nullptr,
                      // loop ? dv_accum.data_ptr() : nullptr,

+ 26 - 4
hopper/flash_attn_interface.py

@@ -72,6 +72,8 @@ def _flash_attn_varlen_forward(
     max_seqlen_k,
     softmax_scale,
     causal,
+    seqused_q=None,
+    seqused_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -82,7 +84,8 @@ def _flash_attn_varlen_forward(
         None,
         cu_seqlens_q,
         cu_seqlens_k,
-        None,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale,
@@ -110,6 +113,8 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     deterministic=False,
+    seqused_q=None,
+    seqused_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
@@ -132,6 +137,8 @@ def _flash_attn_varlen_backward(
         dv,
         cu_seqlens_q,
         cu_seqlens_k,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale,
@@ -207,6 +214,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         deterministic=False,
+        seqused_q=None,
+        seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -220,9 +229,12 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             max_seqlen_k,
             softmax_scale,
             causal=causal,
+            seqused_q=seqused_q,
+            seqused_k=seqused_k,
         )
         ctx.save_for_backward(
-            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+            seqused_q, seqused_k
         )
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_k = max_seqlen_k
@@ -233,7 +245,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_varlen_backward(
             dout,
@@ -252,11 +264,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.deterministic,
+            seqused_q,
+            seqused_k,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_func(
@@ -336,6 +350,8 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     deterministic=False,
+    seqused_q=None,
+    seqused_k=None,
 ):
     """
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -366,6 +382,10 @@ def flash_attn_varlen_func(
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of 
+            query and output tokens in each sequence.
+        seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of 
+            key and value tokens in each sequence.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
@@ -383,4 +403,6 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         deterministic,
+        seqused_q,
+        seqused_k,
     )

+ 5 - 2
hopper/flash_bwd_launch_template.h

@@ -45,7 +45,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         {params.d_rounded, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0},  // stride_dQ
         params.b,
         params.dq_semaphore,
-        params.cu_seqlens_q
+        params.cu_seqlens_q,
+        params.seqused_q
     };
     typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
     int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
@@ -87,6 +88,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         params.b,
         params.dq_semaphore,
         params.cu_seqlens_q, params.cu_seqlens_k,
+        params.seqused_q, params.seqused_k
     };
     typename CollectiveEpilogue::Arguments epilogue_args {
         static_cast<Element*>(params.dk_ptr),
@@ -146,7 +148,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1},  // shape_dQ
         {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride},  // stride_dQ
         params.scale_softmax,
-        params.cu_seqlens_q
+        params.cu_seqlens_q,
+        params.seqused_q
     };
     typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
     int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));

+ 5 - 2
hopper/flash_bwd_postprocess_kernel.h

@@ -102,6 +102,7 @@ public:
         StridedQ const stride_dQ;
         float const softmax_scale;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Kernel entry point API
@@ -113,6 +114,7 @@ public:
         StridedQ const stride_dQ;
         float const softmax_scale;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Convert to underlying arguments. In this case, a simple copy for the aliased type.
@@ -133,7 +135,8 @@ public:
             args.shape_dQ,
             args.stride_dQ,
             args.softmax_scale,
-            args.cu_seqlens
+            args.cu_seqlens,
+            args.seqused
         };
     }
 
@@ -156,7 +159,7 @@ public:
         int const bidb = blockIdx.z;
 
         bool const is_varlen = params.cu_seqlens != nullptr;
-        int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb];
+        int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]);
         if (is_varlen && m_block * kBlockM >= seqlen) { return; }
 
         int lane_predicate = cute::elect_one_sync();

+ 5 - 2
hopper/flash_bwd_preprocess_kernel.h

@@ -85,6 +85,7 @@ public:
         int num_batch;  // We need this to know the size of dq_semaphore in case of varlen
         int* dq_semaphore;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Kernel entry point API
@@ -107,6 +108,7 @@ public:
         int num_batch;
         int* dq_semaphore;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Convert to underlying arguments. In this case, a simple copy for the aliased type.
@@ -131,7 +133,8 @@ public:
             args.stride_dQaccum,
             args.num_batch,
             args.dq_semaphore,
-            args.cu_seqlens
+            args.cu_seqlens,
+            args.seqused
         };
     }
 
@@ -148,7 +151,7 @@ public:
 
         bool const is_varlen = Varlen && params.cu_seqlens != nullptr;
         int const offset_o = !is_varlen ? 0 : params.cu_seqlens[bidb];
-        int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : params.cu_seqlens[bidb + 1] - offset_o;
+        int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o);
         if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
 
         Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);

+ 1 - 1
hopper/flash_fwd_launch_template.h

@@ -37,7 +37,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     >>;
     // using Scheduler = flash::SingleTileScheduler;
     Seqlen_traits seqlen_traits_q(
-        params.total_q, params.seqlen_q, params.cu_seqlens_q);
+        params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q);
     Seqlen_traits seqlen_traits_k(
         params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);
     typename CollectiveMainloop::Params mainloop_params =

+ 14 - 4
hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp

@@ -279,6 +279,8 @@ struct CollectiveMainloopBwd {
         int* dq_semaphore;
         int const* cu_seqlens_q = nullptr;
         int const* cu_seqlens_k = nullptr;
+        int const* seqused_k = nullptr;
+        int const* seqused_v = nullptr;
     };
 
     // Device side kernel params
@@ -303,6 +305,8 @@ struct CollectiveMainloopBwd {
         int* dq_semaphore;
         int const* cu_seqlens_q = nullptr;
         int const* cu_seqlens_k = nullptr;
+        int const* seqused_q = nullptr;
+        int const* seqused_k = nullptr;
     };
 
     static Params
@@ -362,7 +366,8 @@ struct CollectiveMainloopBwd {
                 tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, tma_add_dQ, tma_load_LSE, tma_load_dPsum,
                 args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
                 args.softmax_scale, float(args.softmax_scale * M_LOG2E),
-                args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k};
+                args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k,
+                args.seqused_k, args.seqused_v};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -384,7 +389,10 @@ struct CollectiveMainloopBwd {
         } else {
             return params.cu_seqlens_q == nullptr
                 ? get<0>(params.shape_Q)
-                : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb];
+                : (params.seqused_q
+                    ? params.seqused_q[bidb]
+                    : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]
+                );
         }
     }
 
@@ -395,7 +403,10 @@ struct CollectiveMainloopBwd {
         } else {
             return params.cu_seqlens_k == nullptr
                 ? get<0>(params.shape_K)
-                : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb];
+                : (params.seqused_k
+                    ? params.seqused_k[bidb]
+                    : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]
+                );
         }
     }
 
@@ -838,4 +849,3 @@ struct CollectiveMainloopBwd {
 };
 
 } // namespace flash
-

+ 28 - 3
hopper/test_flash_attn.py

@@ -45,7 +45,7 @@ def print_diffs(out, out_ref):
     "seqlen_q,seqlen_k",
     [
         (1, 1),
-        (257, 1),
+        # (257, 1),
         (64, 128),
         (128, 128),
         (256, 256),
@@ -199,6 +199,8 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [False])
+# @pytest.mark.parametrize("add_unused_qkv", [False, True])
+@pytest.mark.parametrize("add_unused_qkv", [True])
 # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [128])
@@ -231,7 +233,7 @@ def test_flash_attn_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype
+    seqlen_q, seqlen_k, d, causal, deterministic, add_unused_qkv, mha_type, dtype
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -259,12 +261,27 @@ def test_flash_attn_varlen_output(
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
 
+    def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
+        if add_unused:
+            another_mask = generate_random_padding_mask(max_seq_len, bs, device)
+            attn_mask = torch.logical_and(padding_mask, another_mask)
+            unused_mask = torch.logical_xor(torch.logical_or(padding_mask, another_mask), attn_mask)
+        else:
+            attn_mask = padding_mask
+            unused_mask = None
+        return attn_mask, unused_mask
+
+    query_padding_mask, query_unused_mask = _gen_unused_masks(query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device)
+    key_padding_mask, key_unused_mask = _gen_unused_masks(key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device)
+
     (
         q_unpad,
         k_unpad,
         v_unpad,
         cu_seqlens_q,
         cu_seqlens_k,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         q,
@@ -273,7 +290,7 @@ def test_flash_attn_varlen_output(
         output_pad_fn,
         dq_pad_fn,
         dk_pad_fn,
-    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)
     # print("cu_seqlens_q: ", cu_seqlens_q)
     # print("cu_seqlens_k: ", cu_seqlens_k)
     # print("q_unpad, shape: ", q_unpad.shape)
@@ -289,8 +306,12 @@ def test_flash_attn_varlen_output(
         max_seqlen_k,
         causal=causal,
         deterministic=deterministic,
+        seqused_q=seqused_q,
+        seqused_k=seqused_k,
     )
     out = output_pad_fn(out_unpad)
+    q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
+    out.masked_fill_(q_zero_masking, 0.0)
     dropout_mask = None
 
     out_ref, attn_ref = attention_ref(
@@ -326,6 +347,9 @@ def test_flash_attn_varlen_output(
         ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
         dk = dk_pad_fn(dk_unpad)
         dv = dk_pad_fn(dv_unpad)
+        k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
+        dk.masked_fill_(k_zero_masking, 0.0)
+        dv.masked_fill_(k_zero_masking, 0.0)
         (
             dq_ref,
             dk_ref,
@@ -342,6 +366,7 @@ def test_flash_attn_varlen_output(
         dk_pt.masked_fill_(zero_masking, 0.0)
         dv_pt.masked_fill_(zero_masking, 0.0)
         dq = dq_pad_fn(dq_unpad)
+        dq.masked_fill_(q_zero_masking, 0.0)
         print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
         print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
         print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")

+ 15 - 4
tests/test_util.py

@@ -29,7 +29,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random",
 
 
 def generate_qkv(
-    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
+    q, k, v, query_padding_mask=None, key_padding_mask=None, 
+    kvpacked=False, qkvpacked=False, add_unused_qkv=False,
+    query_unused_mask=None, key_unused_mask=None,
 ):
     """
     Arguments:
@@ -44,9 +46,14 @@ def generate_qkv(
     _, seqlen_k, nheads_k, _ = k.shape
     assert k.shape == (batch_size, seqlen_k, nheads_k, d)
     assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+    if query_unused_mask is not None or key_unused_mask is not None:
+        assert not kvpacked
+        assert not qkvpacked
 
     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q  = unpad_input(
+            q, query_padding_mask, query_unused_mask,
+        )
         output_pad_fn = lambda output_unpad: pad_input(
             output_unpad, indices_q, batch_size, seqlen_q
         )
@@ -55,20 +62,22 @@ def generate_qkv(
         cu_seqlens_q = torch.arange(
             0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
         )
+        seqused_q = None
         max_seqlen_q = seqlen_q
         output_pad_fn = lambda output_unpad: rearrange(
             output_unpad, "(b s) h d -> b s h d", b=batch_size
         )
 
     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask)
+        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
         cu_seqlens_k = torch.arange(
             0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
         )
+        seqused_k = None
         max_seqlen_k = seqlen_k
 
     if qkvpacked:
@@ -125,6 +134,8 @@ def generate_qkv(
             v_unpad.detach().requires_grad_(),
             cu_seqlens_q,
             cu_seqlens_k,
+            seqused_q,
+            seqused_k,
             max_seqlen_q,
             max_seqlen_k,
             q.detach().requires_grad_(),