Clean up softcapping bwd a bit

Tri Dao, 8 months ago
parent commit 5ca83a9c71

README.md (+1, -1)

@@ -353,7 +353,7 @@ Thanks to @beginlner for this contribution.
 ### 2.6: Softcapping.
 
 Support attention with softcapping, as used in Gemma-2 and Grok models.
-Thanks to @Narsil for this contribution.
+Thanks to @Narsil and @lucidrains for this contribution.
 
 ## Performance
 

csrc/flash_attn/src/flash_bwd_kernel.h (+5, -16)

@@ -480,16 +480,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
 
         // Softcapping - calculating dTanh and scaling dS later with it
-        auto dtanh = ([&]{
-            if constexpr (Is_softcap) {
-                Tensor _dtanh = make_tensor_like(scores);
-                flash::calculate_dtanh(scores, _dtanh, params.softcap);
-                return _dtanh;
-            }
-            else {
-                return nullptr;
-            }
-        }());
+        Tensor dtanh = make_tensor_like(scores);
+        if constexpr (Is_softcap) {
+            flash::calculate_dtanh(scores, dtanh, params.softcap);
+        }
 
         // Alibi
         if (Has_alibi) {
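
For context: softcapping maps a raw score s to s' = softcap * tanh(s / softcap) in the forward pass, so the backward pass must scale dS by the derivative of that map, 1 - tanh^2(s / softcap) = 1 - (s' / softcap)^2. The cleanup also drops the nullptr-returning lambda: dtanh is now always materialized with make_tensor_like, and when Is_softcap is false it is never written or read, so the compiler can eliminate it. A minimal scalar sketch of the element-wise derivative (hypothetical names, plain arrays standing in for the cute tensors that flash::calculate_dtanh actually takes, and assuming scores hold the post-tanh values up to any scaling folded in elsewhere):

    // Sketch only: assumes scores[i] = softcap * tanh(s_i / softcap).
    inline void calculate_dtanh_sketch(const float* scores, float* dtanh,
                                       int n, float softcap) {
        for (int i = 0; i < n; ++i) {
            float t = scores[i] / softcap;   // recover tanh(s_i / softcap)
            dtanh[i] = 1.f - t * t;          // d/ds [softcap * tanh(s / softcap)]
        }
    }
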
@@ -591,13 +585,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int mi = 0; mi < size<0>(dS); ++mi) {
             #pragma unroll
             for (int ni = 0; ni < size<1>(dS); ++ni) {
-
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
-
-                if constexpr (Is_softcap) {
-                    scaled_ds *= dtanh(mi, ni);
-                }
-
+                if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
             }
         }
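
The loop above implements the standard attention-backward identity dS_ij = P_ij * (dP_ij - D_i), where D_i = rowsum(dO ∘ O) is the per-row term carried in dP_sum, with the extra dtanh factor chained in when softcapping is active. A scalar sketch of one iteration, under the assumption (not confirmed by this excerpt) that pointwise_mult(p, dp, d) evaluates p * (dp - d):

    // Illustrative only; names mirror the loop above.
    inline float ds_element(float p, float dp, float d_row,
                            float dtanh_val, bool is_softcap) {
        float scaled_ds = p * (dp - d_row);          // P * (dP - D)
        if (is_softcap) { scaled_ds *= dtanh_val; }  // chain rule through tanh capping
        return scaled_ds;
    }
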

csrc/flash_attn/src/flash_bwd_launch_template.h (+1, -1)

@@ -99,7 +99,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
                             // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                             // If Is_local, set Is_causal to false
-                            auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+                            auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
                             // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
                             if (smem_size_dq_dk_dv >= 48 * 1024)  {
                                 C10_CUDA_CHECK(cudaFuncSetAttribute(
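
The template-argument change above forces the dropout flag off whenever softcapping is enabled (Is_dropout && !Is_softcap), so the dropout-plus-softcap kernel variant is never instantiated; this follows the file's existing pattern, visible in the comments above, of collapsing unsupported or redundant flag combinations to reduce the number of templates. A small self-contained illustration of the technique, with hypothetical names:

    #include <cstdio>

    template <bool Is_dropout, bool Is_softcap>
    void run_kernel() {
        std::printf("dropout=%d softcap=%d\n", Is_dropout, Is_softcap);
    }

    template <bool Is_dropout, bool Is_softcap>
    void launch() {
        // Forcing dropout off when softcap is on means the <true, true>
        // combination is never stamped out by the compiler.
        run_kernel<Is_dropout && !Is_softcap, Is_softcap>();
    }

    int main() {
        launch<true, true>();   // instantiates run_kernel<false, true>
        launch<false, true>();  // also run_kernel<false, true>: same variant reused
    }
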