
Implement softcapping. (#1025)

* Softcap v2 (fwd only).

* Add the missing interface arguments and remove overrides in tests.
Nicolas Patry, 8 months ago
Commit 8f873cc6ac

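For reviewers new to the feature: softcapping squashes the pre-softmax attention scores through a scaled tanh, bounding them to (-softcap, softcap). A minimal sketch of the operation this commit wires through the stack, mirroring the change made to attention_ref in the tests:

    import torch

    def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
        # scores: pre-softmax attention logits (q @ k^T * softmax_scale).
        # Result is bounded to (-softcap, softcap); softcap <= 0 means the cap is disabled.
        return softcap * torch.tanh(scores / softcap) if softcap > 0 else scores
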
csrc/flash_attn/flash_api.cpp  (+30, -4)

@@ -43,6 +43,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
+                      const float softcap,
                       bool seqlenq_ngroups_swapped=false,
                       const bool unpadded_lse=false) {
 
@@ -100,8 +101,19 @@ void set_params_fprop(Flash_fwd_params &params,
     params.d_rounded = d_rounded;
 
     // Set the different scale values.
-    params.scale_softmax = softmax_scale;
-    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    #ifdef FLASHATTENTION_DISABLE_SOFTCAP
+        TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
+    #endif
+    if (softcap > 0.0) {
+        params.softcap = softmax_scale / softcap;
+        params.scale_softmax = softcap;
+        params.scale_softmax_log2 = softcap * M_LOG2E;
+    } else {
+        // Avoid a potential NaN from softmax_scale / softcap when softcap == 0.
+        params.softcap = 0.0;
+        params.scale_softmax = softmax_scale;
+        params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+    }
 
     // Set this to probability of keeping an element to simplify things.
     params.p_dropout = 1.f - p_dropout;
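
Note on the plumbing above: instead of carrying softmax_scale and softcap separately into the kernel, the cap is folded into the two existing scale slots. params.softcap stores softmax_scale / softcap and is applied to the raw q @ k^T before the tanh, while scale_softmax stores softcap and is applied afterwards as the usual softmax scaling, so the combined effect is softcap * tanh(softmax_scale * s / softcap). A quick equivalence sketch with illustrative numbers (not taken from the code):

    import math

    def reference(s, softmax_scale, softcap):
        return softcap * math.tanh(softmax_scale * s / softcap)

    def kernel_decomposition(s, softmax_scale, softcap):
        pre_tanh = softmax_scale / softcap   # stored in params.softcap
        post_tanh = softcap                  # stored in params.scale_softmax
        return post_tanh * math.tanh(pre_tanh * s)

    s, scale, cap = 12.0, 1 / math.sqrt(64), 50.0
    assert math.isclose(reference(s, scale, cap), kernel_decomposition(s, scale, cap))
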
@@ -172,6 +184,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
+                      const float softcap,
                       bool deterministic,
                       const bool unpadded_lse) {
 
@@ -187,6 +200,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                      softmax_scale,
                      window_size_left,
                      window_size_right,
+                     softcap,
                      false, // seqlenq_ngroups_swapped
                      unpadded_lse);
 
@@ -332,6 +346,7 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         bool is_causal,
         int window_size_left,
         int window_size_right,
+        const float softcap,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
 
@@ -453,7 +468,9 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      p_dropout,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     softcap);
 
 
     set_params_splitkv(params, batch_size, num_heads,
@@ -521,6 +538,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                bool is_causal,
                int window_size_left,
                int window_size_right,
+               const float softcap,
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {
 
@@ -688,6 +706,7 @@ mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \s
                      softmax_scale,
                      window_size_left,
                      window_size_right,
+                     softcap,
                      seqlenq_ngroups_swapped,
                      /*unpadded_lse*/true);
     params.total_q = total_q;
@@ -776,6 +795,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         const bool is_causal,
         int window_size_left,
         int window_size_right,
+        const float softcap,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
@@ -940,6 +960,7 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
                      softmax_scale,
                      window_size_left,
                      window_size_right,
+                     softcap,
                      deterministic,
                      /*unpadded_lse*/false);
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
@@ -1009,6 +1030,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                const bool is_causal,
                int window_size_left,
                int window_size_right,
+               const float softcap,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state) {
@@ -1191,6 +1213,7 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                      softmax_scale,
                      window_size_left,
                      window_size_right,
+                     softcap,
                      deterministic,
                      /*unpadded_lse*/true);
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
@@ -1257,6 +1280,7 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                 bool is_causal,
                 int window_size_left,
                 int window_size_right,
+                const float softcap,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits
                 ) {
@@ -1392,7 +1416,9 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                      /*p_dropout=*/0.f,
                      softmax_scale,
                      window_size_left,
-                     window_size_right);
+                     window_size_right,
+                     softcap);
 
     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {

csrc/flash_attn/src/flash.h  (+1, -0)

@@ -118,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params {
 
     // Local window size
     int window_size_left, window_size_right;
+    float softcap;
 
     // Random state.
     at::PhiloxCudaState philox_args;

csrc/flash_attn/src/flash_fwd_kernel.h  (+35, -6)

@@ -22,6 +22,22 @@ namespace flash {
 
 using namespace cute;
 
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap) {
+    static_assert(Layout::rank == 3, "Only support 3D Tensor");
+    static_assert(decltype(size<0>(tensor))::value == 4, "First dimension must be 4");
+    #pragma unroll
+    for (int i = 0; i < size<0>(tensor); ++i) {  // MMA
+        #pragma unroll
+        for (int mi = 0; mi < size<1>(tensor); ++mi) {
+            #pragma unroll
+            for (int nj = 0; nj < size<2>(tensor); ++nj) {
+                tensor(i, mi, nj) = cutlass::fast_tanh(tensor(i, mi, nj) * softcap);
+            }
+        }
+    }
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
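
A rough Python mirror of apply_softcap above, for readers who prefer not to parse the cute loops: the accumulator fragment is rank 3 with a leading dimension of 4 (per the static_asserts), and only the pre-tanh multiply happens here because params.softcap already holds softmax_scale / softcap; multiplying by the cap value itself is deferred to scale_softmax.

    import torch

    def apply_softcap_ref(acc_s: torch.Tensor, softcap_scale: float) -> torch.Tensor:
        # acc_s: (4, mma_m, mma_n) score fragment; softcap_scale plays the role of params.softcap.
        assert acc_s.dim() == 3 and acc_s.size(0) == 4
        return torch.tanh(acc_s * softcap_scale)
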
@@ -45,7 +61,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params &params, const int bid
 }
 
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -318,6 +334,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
@@ -381,6 +400,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         flash::cp_async_wait<0>();
         __syncthreads();
@@ -486,7 +508,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
 
     using Element = typename Kernel_traits::Element;
@@ -870,6 +892,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
             acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
@@ -941,6 +967,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
+        if constexpr (Is_softcap){
+            apply_softcap(acc_s, params.softcap);
+        }
 
         flash::cp_async_wait<0>();
         __syncthreads();
@@ -1054,7 +1083,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1070,12 +1099,12 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
 
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1084,7 +1113,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

csrc/flash_attn/src/flash_fwd_launch_template.h  (+39, -35)

@@ -26,18 +26,18 @@
 template<typename Kernel_traits, __VA_ARGS__> \
 __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
     #if defined(ARCH_SUPPORTS_FLASH)
         static_assert(!(Is_causal && Is_local)); // Enforce constraints
-        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
+        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
 }
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
     #if defined(ARCH_SUPPORTS_FLASH)
-        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -67,25 +67,27 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                 BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                     ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                        // Will only return softmax if dropout, to reduce compilation time.
-                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                        // If return_softmax, set IsEvenMNConst to false to reduce number of templates
-                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                        // If Is_local, set Is_causal to false
-                        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
-                        // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
-                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
-                        if (smem_size >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                        }
-                        // int ctas_per_sm;
-                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
+                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                            // Will only return softmax if dropout, to reduce compilation time.
+                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                            // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                            // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                            // If Is_local, set Is_causal to false
+                            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+                            // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
+                            // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                            // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+                            if (smem_size >= 48 * 1024) {
+                                C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                            }
+                            // int ctas_per_sm;
+                            // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                            //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                            // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                            kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                            C10_CUDA_KERNEL_LAUNCH_CHECK();
+                        });
                     });
                 });
             });
@@ -109,18 +111,20 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                     BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                         BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                             ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                                // If Is_local, set Is_causal to false
-                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
-                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
-                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                                if (smem_size >= 48 * 1024) {
-                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                                }
-                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                                    // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                                    // If Is_local, set Is_causal to false
+                                    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
+                                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                                    if (smem_size >= 48 * 1024) {
+                                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                                    }
+                                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                                });
                             });
                         });
                     });

csrc/flash_attn/src/static_switch.h  (+10, -0)

@@ -56,6 +56,16 @@
   #define EVENK_SWITCH BOOL_SWITCH
 #endif
 
+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+  #define SOFTCAP_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;    \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define SOFTCAP_SWITCH BOOL_SWITCH
+#endif
+
 #ifdef FLASHATTENTION_DISABLE_LOCAL
   #define LOCAL_SWITCH(COND, CONST_NAME, ...)   \
   [&] {                                         \

flash_attn/flash_attn_interface.py  (+42, -1)

@@ -44,7 +44,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 
 
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -59,6 +59,7 @@ def _flash_attn_forward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         return_softmax,
         None,
     )
@@ -123,6 +124,7 @@ def _flash_attn_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
@@ -151,6 +153,7 @@ def _flash_attn_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -176,6 +179,7 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     window_size,
+    softcap,
     alibi_slopes,
     deterministic,
     rng_state=None,
@@ -209,6 +213,7 @@ def _flash_attn_varlen_backward(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         deterministic,
         None,
         rng_state,
@@ -227,6 +232,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -241,6 +247,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
@@ -249,6 +256,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -272,6 +280,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -433,6 +442,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -451,6 +461,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
@@ -464,6 +475,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -492,6 +504,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+            ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -512,6 +525,7 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -526,6 +540,7 @@ class FlashAttnFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
         )
@@ -534,6 +549,7 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -556,6 +572,7 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -581,6 +598,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_softmax,
@@ -600,6 +618,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             softmax_scale,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
@@ -613,6 +632,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.softmax_scale = softmax_scale
         ctx.causal = causal
         ctx.window_size = window_size
+        ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
         return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -639,6 +659,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.window_size,
+             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
             rng_state=rng_state,
@@ -655,6 +676,7 @@ def flash_attn_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # <=0.0 means deactivate
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -676,6 +698,7 @@ def flash_attn_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Values > 0 enable tanh softcapping of the attention scores, bounding them to (-softcap, softcap).
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
             the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -698,6 +721,7 @@ def flash_attn_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -711,6 +735,7 @@ def flash_attn_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -748,6 +773,7 @@ def flash_attn_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Values > 0 enable tanh softcapping of the attention scores, bounding them to (-softcap, softcap).
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -772,6 +798,7 @@ def flash_attn_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -786,6 +813,7 @@ def flash_attn_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -846,6 +874,7 @@ def flash_attn_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
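
A usage sketch of the extended Python interface; the tensor shapes follow the docstring above, and softcap=50.0 simply mirrors the value exercised in the tests rather than a recommended setting:

    import torch
    from flash_attn import flash_attn_func

    q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

    # softcap > 0 enables tanh softcapping of the scores (forward pass only in this commit).
    out = flash_attn_func(q, k, v, causal=True, softcap=50.0)
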
@@ -860,6 +889,7 @@ def flash_attn_varlen_qkvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -884,6 +914,7 @@ def flash_attn_varlen_qkvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Values > 0 enable tanh softcapping of the attention scores, bounding them to (-softcap, softcap).
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
@@ -908,6 +939,7 @@ def flash_attn_varlen_qkvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -925,6 +957,7 @@ def flash_attn_varlen_kvpacked_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -968,6 +1001,7 @@ def flash_attn_varlen_kvpacked_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Values > 0 enable tanh softcapping of the attention scores, bounding them to (-softcap, softcap).
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -996,6 +1030,7 @@ def flash_attn_varlen_kvpacked_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1014,6 +1049,7 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
@@ -1056,6 +1092,7 @@ def flash_attn_varlen_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Values > 0 enable tanh softcapping of the attention scores, bounding them to (-softcap, softcap).
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -1085,6 +1122,7 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         window_size,
+        softcap,
         alibi_slopes,
         deterministic,
         return_attn_probs,
@@ -1106,6 +1144,7 @@ def flash_attn_with_kvcache(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
@@ -1177,6 +1216,7 @@ def flash_attn_with_kvcache(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Values > 0 enable tanh softcapping of the attention scores, bounding them to (-softcap, softcap).
         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
@@ -1226,6 +1266,7 @@ def flash_attn_with_kvcache(
         causal,
         window_size[0],
         window_size[1],
+        softcap,
         rotary_interleaved,
         num_splits,
     )
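
A corresponding sketch for the kv-cache path; the q / k_cache / v_cache argument names are assumed from the existing flash_attn_with_kvcache signature, and only the softcap keyword comes from this diff:

    import torch
    from flash_attn import flash_attn_with_kvcache

    q = torch.randn(1, 1, 8, 64, device="cuda", dtype=torch.float16)          # one decode step
    k_cache = torch.zeros(1, 4096, 8, 64, device="cuda", dtype=torch.float16)
    v_cache = torch.zeros(1, 4096, 8, 64, device="cuda", dtype=torch.float16)

    out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=True, softcap=50.0)
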

setup.py  (+1, -0)

@@ -203,6 +203,7 @@ if not SKIP_CUDA_BUILD:
                         # "-DFLASHATTENTION_DISABLE_BACKWARD",
                         # "-DFLASHATTENTION_DISABLE_DROPOUT",
                         # "-DFLASHATTENTION_DISABLE_ALIBI",
+                        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                         # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                         # "-DFLASHATTENTION_DISABLE_LOCAL",
                     ]

tests/test_flash_attn.py  (+27, -2)

@@ -216,6 +216,7 @@ def attention_ref(
     dropout_mask=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
     upcast=True,
     reorder_ops=False,
 ):
@@ -253,6 +254,10 @@ def attention_ref(
         scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
     else:
         scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+    if softcap > 0:
+        scores /= softcap
+        scores = scores.tanh()
+        scores *= softcap
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
     if window_size[0] >= 0 or window_size[1] >= 0:
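
The reference change above keeps the capped scores strictly inside (-softcap, softcap). It also explains why the tests below scale q by softcap: with unit-variance inputs the scores sit deep in tanh's linear regime, so the capped and uncapped paths would be numerically indistinguishable and the cap would go untested. A small illustration (my reading of the 'within softcap range' comment):

    import math

    softcap = 50.0
    print(softcap * math.tanh(2.0 / softcap))    # ~2.0  -> cap effectively inactive
    print(softcap * math.tanh(120.0 / softcap))  # ~49.2 -> cap visibly saturates the score
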
@@ -877,8 +882,9 @@ def test_flash_attn_varlen_qkvpacked(
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
 # @pytest.mark.parametrize("dropout_p", [0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -894,6 +900,9 @@ def test_flash_attn_output(
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if softcap > 0:
+        # Scale q so that the attention scores actually reach the softcap range.
+        q = q * softcap
     if kvpacked:
         kv = torch.randn(
             batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
@@ -918,6 +927,7 @@ def test_flash_attn_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -930,6 +940,7 @@ def test_flash_attn_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -984,6 +995,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_kvpacked_ref(
             q,
@@ -995,6 +1007,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )
@@ -1010,6 +1023,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_ref(
             q,
@@ -1022,6 +1036,7 @@ def test_flash_attn_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )
@@ -1133,9 +1148,10 @@ def test_flash_attn_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
+@pytest.mark.parametrize("softcap", [0.0, 50.0])
 # @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
+    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -1151,6 +1167,9 @@ def test_flash_attn_varlen_output(
     assert nheads % nheads_k == 0
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if softcap > 0:
+        # Scale q so that the attention scores actually reach the softcap range.
+        q = q * softcap
 
     if kvpacked:
         kv = torch.randn(
@@ -1199,6 +1218,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -1230,6 +1250,7 @@ def test_flash_attn_varlen_output(
             dropout_p,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             alibi_slopes=alibi_slopes,
             deterministic=deterministic,
             return_attn_probs=True,
@@ -1289,6 +1310,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_kvpacked_ref(
             q,
@@ -1300,6 +1322,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )
@@ -1315,6 +1338,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
         )
         out_pt, attn_pt = attention_ref(
             q,
@@ -1327,6 +1351,7 @@ def test_flash_attn_varlen_output(
             dropout_mask,
             causal=causal,
             window_size=window_size,
+            softcap=softcap,
             upcast=False,
             reorder_ops=True,
         )