Procházet zdrojové kódy

Implement deterministic backward (thanks to Meituan)

Tri Dao před 1 rokem
rodič
revize
732654583c

+ 9 - 3
README.md

@@ -83,7 +83,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 
 ```python
 flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
-                          window_size=(-1, -1), alibi_slopes=None):
+                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
 calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -99,6 +99,8 @@ Arguments:
     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
     alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
         the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
@@ -106,7 +108,7 @@ Return:
 
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
-                window_size=(-1, -1), alibi_slopes=None):
+                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -128,6 +130,8 @@ Arguments:
     alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
         (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
         is added to the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
@@ -269,10 +273,12 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
 AI](https://mistral.ai/) and in particular Timothée Lacroix for this
 contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
 
-### 2.4: ALiBi (attention with linear bias)
+### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
 
 Implement ALiBi (Press et el., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
 
+Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
+
 ## Performance
 
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).

+ 42 - 10
csrc/flash_attn/flash_api.cpp

@@ -150,7 +150,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       float p_dropout,
                       float softmax_scale,
                       int window_size_left,
-                      int window_size_right) {
+                      int window_size_right,
+                      bool deterministic) {
 
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -192,6 +193,8 @@ void set_params_dgrad(Flash_bwd_params &params,
 
     // Softmax sum
     params.dsoftmax_sum = dsoftmax_sum_d;
+
+    params.deterministic = deterministic;
 }
 
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
@@ -618,8 +621,14 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
         params.alibi_slopes_ptr = nullptr;
     }
 
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    run_mha_fwd(params, stream);
+    if (max_seqlen_k > 0) {
+        auto stream = at::cuda::getCurrentCUDAStream().stream();
+        run_mha_fwd(params, stream);
+    } else {
+        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits<float>::infinity());
+    }
 
     at::Tensor out_padded = out;
     if (head_size_og % 8 != 0) {
@@ -668,6 +677,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         const bool is_causal,
         const int window_size_left,
         int window_size_right,
+        const bool deterministic,
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
 
@@ -783,7 +793,12 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     at::Tensor dq_accum;
     at::Tensor dk_accum, dv_accum;
     if (loop) {
-        dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        if (!deterministic) {
+            dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        } else {
+            const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
+            dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        }
         // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
         // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
     }
@@ -819,7 +834,9 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     deterministic);
+    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
@@ -857,8 +874,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         launch(params, stream, /*configure=*/false);
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
-        dk.zero_();
-        dv.zero_();
+        dk_expanded.zero_();
+        dv_expanded.zero_();
         softmax_d.zero_();
     }
 
@@ -897,6 +914,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                const bool is_causal,
                const int window_size_left,
                int window_size_right,
+               const bool deterministic,
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {
 
@@ -1025,7 +1043,12 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
         // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
         // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
-        dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        if (!deterministic) {
+            dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        } else {
+            const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
+            dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+        }
     }
 
     at::Tensor dk_expanded, dv_expanded;
@@ -1064,7 +1087,9 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     deterministic);
+    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
 
     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
@@ -1098,7 +1123,14 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         params.alibi_slopes_ptr = nullptr;
     }
 
-    launch(params, stream, /*configure=*/false);
+    if (max_seqlen_q > 0) {
+        launch(params, stream, /*configure=*/false);
+    } else {
+        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+        dk_expanded.zero_();
+        dv_expanded.zero_();
+        softmax_d.zero_();
+    }
 
     // For MQA/GQA we need to sum dK and dV across the groups
     if (num_heads_k != num_heads) {

+ 3 - 0
csrc/flash_attn/src/flash.h

@@ -172,6 +172,9 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;
+
+    bool deterministic;
+    index_t dq_accum_split_stride;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+ 17 - 9
csrc/flash_attn/src/flash_bwd_kernel.h

@@ -230,7 +230,7 @@ inline __device__ void clear_dKVaccum(const Params &params) {
 // Convert dQ from dQaccum (in float) to fp16/bf16.
 // This is used in the case where we want to parallelize the backward across seqlen_k.
 template<typename Kernel_traits, typename Params>
-inline __device__ void convert_dQ(const Params &params) {
+inline __device__ void convert_dQ(const Params &params, const int nsplits) {
     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
@@ -285,11 +285,15 @@ inline __device__ void convert_dQ(const Params &params) {
     CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
 
     Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
-    cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
-    #pragma unroll
-    for (int i = 0; i < size(acc_dq); ++i) {
-        acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout;
+    clear(acc_dq);
+    for (int s = 0; s < nsplits; ++s) {
+        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
+        #pragma unroll
+        for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
+        tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
     }
+    #pragma unroll
+    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
     // Convert acc_dq from fp32 to fp16
     Tensor rdQ = flash::convert_type<Element>(acc_dq);
     Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
@@ -466,7 +470,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
     const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
+        // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+        + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
     const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
         + (m_block_max - 1) * kBlockM;
     const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
@@ -715,7 +721,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         tdKsQt.data() = tdKsQt.data() + size(sQ);
     }
 
-    if (!Is_first && !Seq_parallel) { __syncthreads(); }
+    if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }
 
     if (Kernel_traits::Is_V_in_regs) {
         // Clear the smem tiles to account for predicated off loads
@@ -1604,13 +1610,15 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
-    const int n_block = blockIdx.x;
     // The block index for the batch.
     const int bidb = blockIdx.y;
     // The block index for the head.
     const int bidh = blockIdx.z;
 
-    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+    for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+ 15 - 5
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -35,8 +35,8 @@ __global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params pa
 }
 
 template<typename Kernel_traits>
-__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) {
-    flash::convert_dQ<Kernel_traits>(params);
+__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
+    flash::convert_dQ<Kernel_traits>(params, nsplits);
 }
 
 template<typename Kernel_traits>
@@ -49,9 +49,18 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);
     const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
-    dim3 grid_n(num_n_block, params.b, params.h);
+    int gridDimx = num_n_block;
+    if (params.deterministic) {
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
+    }
+    dim3 grid_n(gridDimx, params.b, params.h);
 
-    flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+    if (!params.deterministic) {
+        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+    } else {
+        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+    }
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 
     // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
@@ -69,6 +78,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
                         auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
                         if (smem_size_dq_dk_dv >= 48 * 1024)  {
                             C10_CUDA_CHECK(cudaFuncSetAttribute(
                                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -86,7 +96,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
     }
-    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
+    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 

+ 1 - 0
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -52,6 +52,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
                         auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                         // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                         // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
                         if (smem_size >= 48 * 1024) {

+ 103 - 12
flash_attn/flash_attn_interface.py

@@ -122,6 +122,7 @@ def _flash_attn_backward(
     causal,
     window_size,
     alibi_slopes,
+    deterministic,
     rng_state=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
@@ -143,6 +144,7 @@ def _flash_attn_backward(
         causal,
         window_size[0],
         window_size[1],
+        deterministic,
         None,
         rng_state,
     )
@@ -168,6 +170,7 @@ def _flash_attn_varlen_backward(
     causal,
     window_size,
     alibi_slopes,
+    deterministic,
     rng_state=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
@@ -194,6 +197,7 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
+        deterministic,
         None,
         rng_state,
     )
@@ -205,7 +209,15 @@ def _flash_attn_varlen_backward(
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+        ctx,
+        qkv,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        alibi_slopes,
+        deterministic,
+        return_softmax,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -226,6 +238,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.causal = causal
         ctx.window_size = window_size
         ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -248,10 +261,11 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.causal,
             ctx.window_size,
             ctx.alibi_slopes,
+            ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None
 
 
 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@@ -266,6 +280,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         causal,
         window_size,
         alibi_slopes,
+        deterministic,
         return_softmax,
     ):
         if softmax_scale is None:
@@ -292,6 +307,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.causal = causal
         ctx.window_size = window_size
         ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -318,16 +334,26 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.causal,
             ctx.window_size,
             ctx.alibi_slopes,
+            ctx.deterministic,
             rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-        return dqkv, None, None, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+        ctx,
+        q,
+        kv,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        alibi_slopes,
+        deterministic,
+        return_softmax,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -348,6 +374,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.causal = causal
         ctx.window_size = window_size
         ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -371,11 +398,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.causal,
             ctx.window_size,
             ctx.alibi_slopes,
+            ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None
 
 
 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@@ -393,6 +421,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         causal,
         window_size,
         alibi_slopes,
+        deterministic,
         return_softmax,
     ):
         if softmax_scale is None:
@@ -422,6 +451,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.causal = causal
         ctx.window_size = window_size
         ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -449,17 +479,28 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.causal,
             ctx.window_size,
             ctx.alibi_slopes,
+            ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
-        return dq, dkv, None, None, None, None, None, None, None, None, None, None
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
 
 
 class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+        ctx,
+        q,
+        k,
+        v,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        alibi_slopes,
+        deterministic,
+        return_softmax,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -480,6 +521,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.causal = causal
         ctx.window_size = window_size
         ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -501,12 +543,13 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.causal,
             ctx.window_size,
             ctx.alibi_slopes,
+            ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None
 
 
 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -525,6 +568,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         causal,
         window_size,
         alibi_slopes,
+        deterministic,
         return_softmax,
     ):
         if softmax_scale is None:
@@ -554,6 +598,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.causal = causal
         ctx.window_size = window_size
         ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
 
     @staticmethod
@@ -579,12 +624,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.causal,
             ctx.window_size,
             ctx.alibi_slopes,
+            ctx.deterministic,
             rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
@@ -594,6 +640,7 @@ def flash_attn_qkvpacked_func(
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     alibi_slopes=None,
+    deterministic=False,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -615,6 +662,8 @@ def flash_attn_qkvpacked_func(
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
             the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -628,7 +677,14 @@ def flash_attn_qkvpacked_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnQKVPackedFunc.apply(
-        qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
+        qkv,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        alibi_slopes,
+        deterministic,
+        return_attn_probs,
     )
 
 
@@ -640,6 +696,7 @@ def flash_attn_kvpacked_func(
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     alibi_slopes=None,
+    deterministic=False,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -678,6 +735,8 @@ def flash_attn_kvpacked_func(
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -691,7 +750,15 @@ def flash_attn_kvpacked_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnKVPackedFunc.apply(
-        q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
+        q,
+        kv,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        alibi_slopes,
+        deterministic,
+        return_attn_probs,
     )
 
 
@@ -704,6 +771,7 @@ def flash_attn_func(
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     alibi_slopes=None,
+    deterministic=False,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -740,6 +808,8 @@ def flash_attn_func(
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -753,7 +823,16 @@ def flash_attn_func(
             pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnFunc.apply(
-        q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
+        q,
+        k,
+        v,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        alibi_slopes,
+        deterministic,
+        return_attn_probs,
     )
 
 
@@ -766,6 +845,7 @@ def flash_attn_varlen_qkvpacked_func(
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     alibi_slopes=None,
+    deterministic=False,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -790,6 +870,8 @@ def flash_attn_varlen_qkvpacked_func(
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
             is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -811,6 +893,7 @@ def flash_attn_varlen_qkvpacked_func(
         causal,
         window_size,
         alibi_slopes,
+        deterministic,
         return_attn_probs,
     )
 
@@ -827,6 +910,7 @@ def flash_attn_varlen_kvpacked_func(
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     alibi_slopes=None,
+    deterministic=False,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -871,6 +955,8 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -895,6 +981,7 @@ def flash_attn_varlen_kvpacked_func(
         causal,
         window_size,
         alibi_slopes,
+        deterministic,
         return_attn_probs,
     )
 
@@ -912,6 +999,7 @@ def flash_attn_varlen_func(
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
     alibi_slopes=None,
+    deterministic=False,
     return_attn_probs=False,
 ):
     """dropout_p should be set to 0.0 during evaluation
@@ -954,6 +1042,8 @@ def flash_attn_varlen_func(
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -979,6 +1069,7 @@ def flash_attn_varlen_func(
         causal,
         window_size,
         alibi_slopes,
+        deterministic,
         return_attn_probs,
     )
 

+ 177 - 11
tests/test_flash_attn.py

@@ -566,10 +566,12 @@ def get_dropout_fraction(
 
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
+# @pytest.mark.parametrize("alibi", [False])
 @pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@@ -578,16 +580,16 @@ def get_dropout_fraction(
 # @pytest.mark.parametrize("d", [64])
 # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
 @pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
-# @pytest.mark.parametrize("seqlen", [97])
+# @pytest.mark.parametrize("seqlen", [512])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.0])
-def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
+def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 4
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
     qkv = torch.randn(
@@ -604,6 +606,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
         causal=causal,
         window_size=window_size,
         alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
         return_attn_probs=True,
     )
     if dropout_p > 0.0:
@@ -712,6 +715,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
 
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
@@ -725,7 +730,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
 # @pytest.mark.parametrize('seqlen', [128])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
+def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = "cuda"
@@ -760,6 +765,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
         causal=causal,
         window_size=window_size,
         alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
         return_attn_probs=True,
     )
     out = output_pad_fn(out_unpad)
@@ -859,6 +865,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
@@ -890,7 +898,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.17])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -900,7 +908,7 @@ def test_flash_attn_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 4
     nheads = 9
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
@@ -931,6 +939,7 @@ def test_flash_attn_output(
             causal=causal,
             window_size=window_size,
             alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
             return_attn_probs=True,
         )
     else:
@@ -942,6 +951,7 @@ def test_flash_attn_output(
             causal=causal,
             window_size=window_size,
             alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
             return_attn_probs=True,
         )
     if dropout_p > 0.0:
@@ -1114,6 +1124,8 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
 # @pytest.mark.parametrize('mha_type', ["mqa"])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
@@ -1143,7 +1155,7 @@ def test_flash_attn_output(
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -1153,7 +1165,7 @@ def test_flash_attn_varlen_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 4
     nheads = 9
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
@@ -1207,6 +1219,7 @@ def test_flash_attn_varlen_output(
             causal=causal,
             window_size=window_size,
             alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
             return_attn_probs=True,
         )
     else:
@@ -1237,6 +1250,7 @@ def test_flash_attn_varlen_output(
             causal=causal,
             window_size=window_size,
             alibi_slopes=alibi_slopes,
+            deterministic=deterministic,
             return_attn_probs=True,
         )
     out = output_pad_fn(out_unpad)
@@ -1675,6 +1689,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
 
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
 @pytest.mark.parametrize("alibi", [False, True])
 # @pytest.mark.parametrize("alibi", [True])
 @pytest.mark.parametrize("local", [False, True])
@@ -1704,7 +1720,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
-def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, dtype):
+def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype):
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
@@ -1729,6 +1745,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
         causal=causal,
         window_size=window_size,
         alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
         return_attn_probs=True,
     )
     out_ref, attn_ref = attention_ref(
@@ -2224,3 +2241,152 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
     assert not q.grad.isnan().any()
     assert not k.grad.isnan().any()
     assert not v.grad.isnan().any()
+
+
+@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+# @pytest.mark.parametrize("swap_sq_sk", [False])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 4
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
+
+    g = torch.randn_like(out)
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+        for _ in range(50):
+            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)
+
+
+
+
+@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("local", [False, True])
+# @pytest.mark.parametrize("local", [True])
+@pytest.mark.parametrize("causal", [False, True])
+# @pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
+# @pytest.mark.parametrize('d', [56, 80])
+# @pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize("swap_sq_sk", [False, True])
+# @pytest.mark.parametrize("swap_sq_sk", [True])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 239),
+        (3, 799),
+        (127, 512),
+        (127, 513),
+        (113, 203),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (1023, 1024),
+    ],
+)
+# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
+def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    if swap_sq_sk:
+        seqlen_q, seqlen_k = seqlen_k, seqlen_q
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 9
+    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+        dq_pad_fn,
+        dk_pad_fn,
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    out = flash_attn_varlen_func(
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        0.0,
+        causal=causal,
+        window_size=window_size,
+        deterministic=True,
+    )
+
+    g = torch.randn_like(out)
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        for _ in range(50):
+            dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+            assert torch.equal(dv, dv)
+            assert torch.equal(dk, dk)
+            assert torch.equal(dq, dq)