Browse source

FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173)

jayhshah 7 months ago
parent
commit
c92ca63268

+ 25 - 7
hopper/benchmark_flash_attention_fp8.py

@@ -13,7 +13,7 @@ from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchma
 from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
 
 from flash_attn import flash_attn_qkvpacked_func
-from flash_attn_interface import flash_attn_func
+from flash_attn_interface import flash_attn_func, _flash_attn_forward
 
 try:
     from triton_fused_attention import attention as attention_triton
@@ -219,12 +219,12 @@ device = 'cuda'
 # dtype = torch.float16
 dtype = torch.float8_e4m3fn
 
-bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
-# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
-# bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2), (4, 4224), (2, 8448), (1, 8448 * 2)]
+# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
+# bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2)]
 # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
 causal_vals = [False, True]
-headdim_vals = [128]
+headdim_vals = [64, 128, 256]
 dim = 2048
 # dim = 256
 dropout_p = 0.0
@@ -276,8 +276,26 @@ for causal in causal_vals:
                 torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
 
             # out = torch.empty_like(q)
-            q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)                        
-            f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
+            q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
+            softmax_scale = q.shape[-1] ** (-0.5)
+            descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda')
+            descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda')
+            descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda')
+
+            # f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
+            f = time_fwd(
+                _flash_attn_forward,
+                q, 
+                k, 
+                v, 
+                softmax_scale, 
+                causal=causal, 
+                descale_q=descale_q, 
+                descale_k=descale_k, 
+                descale_v=descale_v, 
+                repeats=repeats, 
+                verbose=False
+            )
 
             # res = flash_attn_func(q, k, v, causal=causal)
             # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)

+ 3 - 0
hopper/flash.h

@@ -131,6 +131,9 @@ struct Flash_fwd_params : public Qkv_params {
     bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
 
     int * __restrict__ tile_count_semaphore;
+    float * __restrict__ descale_q_ptr;
+    float * __restrict__ descale_k_ptr;
+    float * __restrict__ descale_v_ptr;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+ 30 - 1
hopper/flash_api.cpp

@@ -105,7 +105,7 @@ void set_params_fprop(Flash_fwd_params &params,
     params.d = d;
     params.d_rounded = d_rounded;
 
-    // Set the different scale values.
+    // Set the different scale values.    
     params.scale_softmax = softmax_scale;
     params.scale_softmax_log2 = softmax_scale * M_LOG2E;
     __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
@@ -265,6 +265,9 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &v,         // batch_size x seqlen_k x num_heads_k x head_size
         c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
         const float softmax_scale,
+        c10::optional<at::Tensor> &descale_q_, // 1
+        c10::optional<at::Tensor> &descale_k_, // 1
+        c10::optional<at::Tensor> &descale_v_, // 1
         bool is_causal) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -372,6 +375,32 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
     params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
 
+    if(q_dtype == at::ScalarType::Float8_e4m3fn) {
+        at::Tensor descale_q, descale_k, descale_v;
+        if (descale_q_.has_value() && descale_k_.has_value() && descale_v_.has_value()) {
+            descale_q = descale_q_.value();
+            descale_k = descale_k_.value();
+            descale_v = descale_v_.value();
+            CHECK_DEVICE(descale_q);
+            CHECK_DEVICE(descale_k);
+            CHECK_DEVICE(descale_v);
+            CHECK_SHAPE(descale_q, 1);
+            CHECK_SHAPE(descale_k, 1);
+            CHECK_SHAPE(descale_v, 1);
+        } else {
+            descale_q = torch::ones({1}, opts.dtype(at::kFloat));
+            descale_k = torch::ones({1}, opts.dtype(at::kFloat));
+            descale_v = torch::ones({1}, opts.dtype(at::kFloat));
+        }
+        params.descale_q_ptr = descale_q.data_ptr<float>();
+        params.descale_k_ptr = descale_k.data_ptr<float>();
+        params.descale_v_ptr = descale_v.data_ptr<float>();
+    } else {
+        params.descale_q_ptr = nullptr;
+        params.descale_k_ptr = nullptr;
+        params.descale_v_ptr = nullptr;
+    }
+
     if (seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
         run_mha_fwd(params, stream);
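Semantically, the three optional descales are per-tensor dequantization factors, and defaulting them to ones preserves the old behavior. A minimal PyTorch reference of what the kernel computes (a sketch assuming a (batch, heads, seqlen, dim) layout, not the library's code):

```python
import torch

def attn_ref_descaled(q, k, v, descale_q, descale_k, descale_v, causal=False):
    # Dequantize with the per-tensor factors; the kernel gets the same result
    # by folding descale_q * descale_k into the softmax scale and descale_v
    # into the output normalization.
    q = q.float() * descale_q
    k = k.float() * descale_k
    v = v.float() * descale_v
    scores = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    if causal:
        keep = torch.ones(scores.shape[-2:], dtype=torch.bool,
                          device=scores.device).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```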

+ 4 - 1
hopper/flash_attn_interface.py

@@ -14,7 +14,7 @@ import flashattn_hopper_cuda
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
-def _flash_attn_forward(q, k, v, softmax_scale, causal):
+def _flash_attn_forward(q, k, v, softmax_scale, causal, descale_q = None, descale_k = None, descale_v = None):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
         q,
@@ -22,6 +22,9 @@ def _flash_attn_forward(q, k, v, softmax_scale, causal):
         v,
         None,
         softmax_scale,
+        descale_q,
+        descale_k,
+        descale_v,
         causal,
     )
     return out, q, k, v, out_padded, softmax_lse, S_dmask
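A usage sketch of the extended signature, mirroring the test and benchmark changes in this commit (the unit descale values here are illustrative):

```python
import torch
from flash_attn_interface import _flash_attn_forward

batch, seqlen, nheads, d = 2, 1024, 16, 128
q, k, v = (torch.randn(batch, seqlen, nheads, d, device="cuda",
                       dtype=torch.float16).to(torch.float8_e4m3fn)
           for _ in range(3))
descale_q, descale_k, descale_v = (
    torch.tensor([1.0], dtype=torch.float32, device="cuda") for _ in range(3)
)
out, _q, _k, _v, _out_padded, lse, _S_dmask = _flash_attn_forward(
    q, k, v, d ** -0.5, causal=True,
    descale_q=descale_q, descale_k=descale_k, descale_v=descale_v,
)
```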

+ 10 - 5
hopper/flash_fwd_kernel.h

@@ -157,7 +157,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
              work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
             // Attention output (GEMM-II) accumulator.
             Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
-            flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
+            flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax(mainloop_params.softmax_scale_log2);
 
             auto block_coord = work_tile_info.get_block_coord(scheduler_params);
             auto [m_block, bidh, bidb] = block_coord;
@@ -212,9 +212,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
     static constexpr int kBlockM = Ktraits::kBlockM;
     // static constexpr int kBlockN = Ktraits::kBlockN;
     // static constexpr int kHeadDim = Ktraits::kHeadDim;
-    static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128;  
-    // for now, disable for hdim 128 causal to avoid perf regression with register spilling
-    static constexpr bool Use_max_offset = !(Is_causal && Ktraits::kHeadDim == 128);    
+    static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128;    
+    static constexpr bool Use_max_offset = true;
 
     using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>;
     using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits>;
@@ -269,6 +268,12 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
     CollectiveMainloop collective_mainloop;
     CollectiveEpilogue collective_epilogue;
 
+    float descale_q = *mainloop_params.descale_q_ptr;
+    float descale_k = *mainloop_params.descale_k_ptr;
+    float descale_v = *mainloop_params.descale_v_ptr;
+    shared_storage.softmax_scale_qk_log2 = mainloop_params.softmax_scale_log2 * descale_q * descale_k;
+    shared_storage.descale_v = descale_v;
+
     // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
     if constexpr (size(ClusterShape{}) > 1) {
         cute::cluster_arrive_relaxed();
@@ -341,7 +346,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
              work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
             // Attention output (GEMM-II) accumulator.
             Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
-            flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), Use_max_offset> softmax;
+            flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), Use_max_offset> softmax(shared_storage.softmax_scale_qk_log2);
 
             auto block_coord = work_tile_info.get_block_coord(scheduler_params);
             auto [m_block, bidh, bidb] = block_coord;
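Folding the q/k descales into the log2-domain softmax scale is exact because per-tensor factors commute with the dot product; a quick numeric check of the identity the kernel relies on:

```python
import math

# With q = q_fp8 * dq and k = k_fp8 * dk (per-tensor factors),
#   scale * (q . k) == (scale * dq * dk) * (q_fp8 . k_fp8),
# so softmax_scale_qk_log2 = softmax_scale_log2 * dq * dk lets the kernel
# exponentiate the raw FP8-domain dot product directly via exp2.
scale, dq, dk = 0.125, 0.5, 0.25
qk_fp8 = 7.0  # illustrative FP8-domain dot product
lhs = math.exp(scale * (dq * qk_fp8 * dk))                     # dequantized q.k
rhs = 2.0 ** ((scale * math.log2(math.e) * dq * dk) * qk_fp8)  # kernel form
assert math.isclose(lhs, rhs)
```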

+ 63 - 30
hopper/flash_fwd_launch_template.h

@@ -57,7 +57,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                 params.seqlen_k, params.d, params.h_k, params.b, 
                 params.v_row_stride, params.v_head_stride, params.v_batch_stride
             ),  // layout_V
-            params.scale_softmax_log2
+            params.scale_softmax_log2,
+            params.descale_q_ptr,
+            params.descale_k_ptr,
+            params.descale_v_ptr
         });
     typename CollectiveEpilogue::Params epilogue_params =
         CollectiveEpilogue::to_underlying_arguments({
@@ -160,16 +163,26 @@ void run_mha_fwd_hdim64_fp8(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockN = 128;
     constexpr static int kNWarps = 4 + kBlockM/16;
     constexpr static int kStages = 4;    
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
-            // Only use Cluster if number of tiles along seqlen_q is even
-            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
-                        !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
-                run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                              false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);            
-            });
+    using Seqlen_traits = flash::FixedSeqLenTraits;
+    if(params.is_causal) {
+        run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                        false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
+    } else {
+        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
+            run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
         });
-    });    
+    }
+    // BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+        // SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            // Only use Cluster if number of tiles along seqlen_q is even
+            // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
+            //             !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+            //     run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+            //                   false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);            
+            // });
+        // });
+    // });
 }
 
 template<typename T>
@@ -178,17 +191,27 @@ void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 128;
     constexpr static int kBlockN = 256;
     constexpr static int kNWarps = 4 + kBlockM/16;
-    constexpr static int kStages = 2;    
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
-            // Only use Cluster if number of tiles along seqlen_q is even
-            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
-                        !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
-                run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                              false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
-            });
+    constexpr static int kStages = 2;
+    using Seqlen_traits = flash::FixedSeqLenTraits;
+    if(params.is_causal) {
+        run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                        false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
+    } else {
+        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
+            run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
         });
-    });    
+    }
+    // BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+        // SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            // Only use Cluster if number of tiles along seqlen_q is even
+            // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
+            //             !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+            //     run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+            //                   false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
+            // });
+        // });
+    // });
 }
 
 template<typename T>
@@ -197,15 +220,25 @@ void run_mha_fwd_hdim256_fp8(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 128;
     constexpr static int kBlockN = 128;
     constexpr static int kNWarps = 4 + kBlockM/16;
-    constexpr static int kStages = 2;    
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
-            // Only use Cluster if number of tiles along seqlen_q is even
-            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
-                        !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
-                run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
-                              false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
-            });
+    constexpr static int kStages = 2;
+    using Seqlen_traits = flash::FixedSeqLenTraits;
+    if(params.is_causal) {
+        run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                        false, 1, T>, /*Is_causal=*/true, Seqlen_traits>(params, stream);
+    } else {
+        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0, UseCluster, [&] {
+            run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                            false, UseCluster ? 2 : 1, T>, /*Is_causal=*/false, Seqlen_traits>(params, stream);
         });
-    });    
+    }
+    // BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+        // SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            // Only use Cluster if number of tiles along seqlen_q is even
+            // BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
+            //             !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+            //     run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+            //                   false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
+            // });
+        // });
+    // });
 }
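The rewritten FP8 dispatch pins these paths to fixed-length sequences and drops the generic switches: causal launches never use thread-block clusters, and non-causal launches use a 2-CTA cluster only when the number of query tiles is even. The same decision in Python (a sketch of the branch logic above, not the CUDA API):

```python
def fp8_cluster_size(seqlen_q: int, k_block_m: int, is_causal: bool) -> int:
    """Mirror of the new launch branches: no cluster for causal; a 2-CTA
    cluster for non-causal only when the query-tile count is even."""
    num_m_tiles = -(-seqlen_q // k_block_m)  # ceil_div
    if is_causal:
        return 1
    return 2 if num_m_tiles % 2 == 0 else 1

assert fp8_cluster_size(8192, 128, is_causal=False) == 2  # 64 tiles, even
assert fp8_cluster_size(4224, 128, is_causal=False) == 1  # 33 tiles, odd
assert fp8_cluster_size(8192, 128, is_causal=True) == 1   # causal: no cluster
```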

+ 2 - 0
hopper/kernel_traits.h

@@ -52,6 +52,8 @@ struct SharedStorageQKVOVt {
     typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
     typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
     int tile_count_semaphore;
+    float softmax_scale_qk_log2;
+    float descale_v;
   };
 };
 

+ 40 - 35
hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

@@ -154,7 +154,10 @@ struct CollectiveMainloopFwd {
         typename Seqlen_traits::LayoutT layout_K;
         Element const* ptr_V;
         typename Seqlen_traits::LayoutT layout_V;
-        float const softmax_scale_log2;
+        float const softmax_scale_log2;        
+        float const* descale_q_ptr;
+        float const* descale_k_ptr;
+        float const* descale_v_ptr;
     };
 
     // Device side kernel params
@@ -166,7 +169,10 @@ struct CollectiveMainloopFwd {
         TMA_Q tma_load_Q;        
         TMA_K tma_load_K;
         TMA_V tma_load_V;
-        float const softmax_scale_log2;
+        float const softmax_scale_log2;        
+        float const* descale_q_ptr;
+        float const* descale_k_ptr;
+        float const* descale_v_ptr;
     };
 
 
@@ -196,7 +202,8 @@ struct CollectiveMainloopFwd {
         return {args.layout_Q, args.layout_K, args.layout_V,
                 cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
                 tma_load_Q, tma_load_K, tma_load_V,
-                args.softmax_scale_log2};
+                args.softmax_scale_log2,
+                args.descale_q_ptr, args.descale_k_ptr, args.descale_v_ptr};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -367,6 +374,7 @@ struct CollectiveMainloopFwd {
                                 flatten(sVt_divide(_, i, j, stage)));
                 }
             }
+            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
         };
 
         Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
@@ -720,7 +728,7 @@ struct CollectiveMainloopFwd {
             }
         }
 
-        softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
+        softmax.template online_softmax</*Is_first=*/true>(tSrS);
         Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
         Tensor scores_scale = make_fragment_like(softmax.row_max);
         clear(scores_scale);
@@ -747,8 +755,8 @@ struct CollectiveMainloopFwd {
                     tSrS(i) = -INFINITY;
                 }
             }
-            cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
-            softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
+            cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
+            softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
             warpgroup_wait<0>();
             pipeline_v.consumer_release(smem_pipe_read_v);  // release V
             ++smem_pipe_read_k;
@@ -769,8 +777,8 @@ struct CollectiveMainloopFwd {
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k);  // release K
             // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
-            cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
-            softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+            cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
+            softmax.template online_softmax</*Is_first=*/false>(tSrS);
             warpgroup_wait<0>();
             pipeline_v.consumer_release(smem_pipe_read_v);  // release V
             ++smem_pipe_read_k;
@@ -783,7 +791,7 @@ struct CollectiveMainloopFwd {
         softmax.rescale_o(tOrO, scores_scale);
         consumer_wait(pipeline_v, smem_pipe_read_v);
         flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS), scores_scale);
         warpgroup_wait<0>();
         pipeline_v.consumer_release(smem_pipe_read_v);  // release V, otherwise producers will hang
         ++smem_pipe_read_v;
@@ -881,7 +889,7 @@ struct CollectiveMainloopFwd {
             }
         }
 
-        softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
+        softmax.template online_softmax</*Is_first=*/true>(tSrS);
         Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
         permute_regs_A_to_C(tOrP);
         
@@ -915,18 +923,18 @@ struct CollectiveMainloopFwd {
 
                 warp_scheduler_barrier_arrive();
                 pipeline_k.consumer_release(smem_pipe_read);
+                if constexpr(Delay_V_release) {
+                    pipeline_vt.consumer_release(smem_pipe_release);
+                    ++smem_pipe_release;
+                }
                 consumer_wait(pipeline_vt, smem_pipe_read);
                 
-                cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS), scores_scale);
                 softmax.rescale_o(tOrO, scores_scale);
-                softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
+                softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS);
                 Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                 permute_regs_A_to_C(tOrP);
-
-                if constexpr(Delay_V_release) {
-                    pipeline_vt.consumer_release(smem_pipe_release);
-                    ++smem_pipe_release;
-                }
+                
                 flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);            
                 if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }                
                 ++smem_pipe_read;
@@ -946,9 +954,9 @@ struct CollectiveMainloopFwd {
                 if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
                 else { consumer_wait(pipeline_vt, smem_pipe_read); }
                 
-                cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
                 softmax.rescale_o(tOrO, scores_scale);
-                softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+                softmax.template online_softmax</*Is_first=*/false>(tSrS);
                 Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                 permute_regs_A_to_C(tOrP);
 
@@ -965,23 +973,21 @@ struct CollectiveMainloopFwd {
             CUTLASS_PRAGMA_NO_UNROLL
             for (; n_block >= 0; --n_block) {
                 Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
-                consumer_wait(pipeline_k, smem_pipe_read);
-                pipeline_vt.consumer_release(smem_pipe_release);                
-                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
-                warp_scheduler_barrier_arrive();
-                warpgroup_wait<0>();                
-                consumer_wait(pipeline_vt, smem_pipe_read);
+                consumer_wait(pipeline_k, smem_pipe_read);                
+                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);                
+                warp_scheduler_barrier_arrive();                
+                pipeline_k.consumer_release(smem_pipe_read);
+                pipeline_vt.consumer_release(smem_pipe_release);
 
-                cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
                 softmax.rescale_o(tOrO, scores_scale);
-                softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+                softmax.template online_softmax</*Is_first=*/false>(tSrS);
                 Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                 permute_regs_A_to_C(tOrP);
-
-                pipeline_k.consumer_release(smem_pipe_read);
-                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
+                
+                consumer_wait(pipeline_vt, smem_pipe_read);
+                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                 warp_scheduler_barrier_sync();
-                warpgroup_wait<0>();
                 ++smem_pipe_read;
                 ++smem_pipe_release;
             }
@@ -999,9 +1005,9 @@ struct CollectiveMainloopFwd {
                 warp_scheduler_barrier_arrive();
                 pipeline_k.consumer_release(smem_pipe_read);
 
-                cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                cute::copy(softmax.template max</*Is_first=*/false>(tSrS), scores_scale);
                 softmax.rescale_o(tOrO, scores_scale);
-                softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+                softmax.template online_softmax</*Is_first=*/false>(tSrS);
                 Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                 permute_regs_A_to_C(tOrP);
 
@@ -1014,8 +1020,7 @@ struct CollectiveMainloopFwd {
             if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
         }
         cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
-        
-        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, shared_storage.descale_v), scores_scale);
         softmax.rescale_o(tOrO, scores_scale);
         return;
     }
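The added `NamedBarrier::sync` on `FwdNamedBarriers::ProducerWG` makes every thread of the producer warpgroup wait until its cooperative in-smem V transpose is complete before the tile is handed onward. An illustrative CPU analogy of that producer-side barrier (the thread count and tile here are hypothetical, not the CUDA code):

```python
import threading

NUM_PRODUCERS = 4                      # stand-in for one warpgroup's threads
barrier = threading.Barrier(NUM_PRODUCERS)
tile = [None] * NUM_PRODUCERS

def producer(tid: int):
    tile[tid] = f"Vt slice {tid}"      # each thread writes its part of Vt
    barrier.wait()                     # analogue of NamedBarrier::sync
    if tid == 0:                       # only now is the whole tile valid
        assert all(s is not None for s in tile)

threads = [threading.Thread(target=producer, args=(t,))
           for t in range(NUM_PRODUCERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```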

+ 8 - 7
hopper/softmax.h

@@ -137,11 +137,12 @@ struct Softmax {
 
     using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
     TensorT row_max, row_sum;
+    const float softmax_scale_log2;
 
-    CUTLASS_DEVICE Softmax() {};
+    CUTLASS_DEVICE Softmax(float scale_ = 1.f) : softmax_scale_log2(scale_) {};
 
     template<bool Is_first, bool Check_inf=false, typename Tensor0>
-    __forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
+    __forceinline__ __device__ TensorT max(Tensor0 &acc_s) {
         // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         static_assert(decltype(size<0>(scores))::value == kNRows);
@@ -166,7 +167,7 @@ struct Softmax {
     };
 
     template<bool Is_first, bool Check_inf=false, typename Tensor0>
-    __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
+    __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s) {
         // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         static_assert(decltype(size<0>(scores))::value == kNRows);
@@ -197,10 +198,10 @@ struct Softmax {
         }
         return scores_scale;
     };
-    
+
     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
-    __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float softmax_scale_log2, float rp_dropout=1.0) {
-        constexpr static float max_offset_E = Use_max_offset ? 8.0f * float(M_LN2) : 0.0f;
+    __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float descale_v = 1.f, float rp_dropout=1.f) {
+        constexpr static float max_offset_E = Use_max_offset ? 8.f * float(M_LN2) : 0.f;
         // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         static_assert(decltype(size<0>(scores))::value == kNRows);
@@ -210,7 +211,7 @@ struct Softmax {
         #pragma unroll
         for (int mi = 0; mi < size(row_max); ++mi) {
             float sum = row_sum(mi);
-            float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;
+            float inv_sum = (sum == 0.f || sum != sum) ? 0.f : descale_v / sum;
             row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : (row_max(mi) * softmax_scale_log2) * float(M_LN2) - max_offset_E + __logf(sum);
             scores_scale(mi) = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
         }
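Restoring `Use_max_offset` means the unnormalized probabilities carry an extra offset of 8·ln 2, i.e. they are scaled up by 2^8 = 256; the intent, hedged here, is to push them nearer the E4M3 maximum of 448 before P is quantized for the second GEMM. The factor cancels in the normalization, and `finalize` subtracts the offset back out of the log-sum-exp. A numeric check of both identities (a sketch assuming the offset enters exactly as `max_offset_E` above):

```python
import numpy as np

rng = np.random.default_rng(0)
s = 0.125                                # softmax scale
x = rng.normal(size=16) * 4.0            # one row of attention scores
m = (s * x).max()
c = 8.0 * np.log(2.0)                    # max_offset_E = 8 * M_LN2
p = np.exp(s * x - m)                    # plain online-softmax numerator
p_off = np.exp(s * x - m + c)            # offset numerator: p * 256
assert np.allclose(p / p.sum(), p_off / p_off.sum())  # softmax unchanged
# finalize(): lse = max*scale - max_offset + log(sum) recovers the true LSE
assert np.isclose(m + np.log(p.sum()), m - c + np.log(p_off.sum()))
```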

+ 27 - 8
hopper/test_flash_attn.py

@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 
 from einops import rearrange, repeat
-from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, _flash_attn_forward
 from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
 
 ABS_TOL = 5e-3
@@ -37,8 +37,10 @@ def print_diffs(out, out_ref):
 # @pytest.mark.parametrize('d', [56, 80])
 # @pytest.mark.parametrize("d", [64, 128, 256])
 # @pytest.mark.parametrize("d", [64, 96, 128])
-@pytest.mark.parametrize("d", [64, 128])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [64, 128])
+@pytest.mark.parametrize("d", [64, 128, 256])
+@pytest.mark.parametrize("descale", [1.0])
+# @pytest.mark.parametrize("descale", [1.0, 2.0, 3.0, 4.0])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -63,7 +65,7 @@ def print_diffs(out, out_ref):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype,
+    seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype, descale
 ):
     device = "cuda"
     if(dtype == torch.float8_e4m3fn):
@@ -89,12 +91,29 @@ def test_flash_attn_output(
     k = k.to(dtype)
     v = v.to(dtype)
 
-    out, lse = flash_attn_func(q, k, v, causal=causal, deterministic=deterministic)
+    softmax_scale = q.shape[-1] ** (-0.5)
+    descale_q = torch.tensor([descale], dtype=torch.float32, device='cuda')
+    descale_k = torch.tensor([descale], dtype=torch.float32, device='cuda')
+    descale_v = torch.tensor([descale], dtype=torch.float32, device='cuda')
+    if(dtype != torch.float8_e4m3fn):
+        out, lse = flash_attn_func(q, k, v, causal=causal, deterministic=deterministic)
+    else:
+        out, q, k, v, out_padded, lse, S_dmask = _flash_attn_forward(
+            q, k, v, softmax_scale, causal, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v
+        )
 
     q = q.to(dtype_init)
     k = k.to(dtype_init)
     v = v.to(dtype_init)
-    
+
+    if(dtype == torch.float8_e4m3fn):
+        descale_q = descale_q.to(dtype_init)
+        descale_k = descale_k.to(dtype_init)
+        descale_v = descale_v.to(dtype_init)
+        q = q * descale_q
+        k = k * descale_k
+        v = v * descale_v
+        
     out_ref, attn_ref = attention_ref(
         q,
         k,
@@ -130,7 +149,7 @@ def test_flash_attn_output(
     #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")                
     # breakpoint()
 
-    if d <= 128:
+    if d <= 128 and dtype != torch.float8_e4m3fn:
         g = torch.randn_like(out)
         do_o = (g.float() * out.float()).sum(-1)
         dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
@@ -166,7 +185,7 @@ def test_flash_attn_output(
         # just test correctness of fp8 kernel w/o further quantization techniques
         assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()
 
-    if d <= 128:
+    if d <= 128 and dtype != torch.float8_e4m3fn:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 3e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 3e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 3e-5
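Putting the pieces together, an end-to-end sanity sketch in the spirit of the test's new FP8 branch (assumes a Hopper GPU, the built `flashattn_hopper_cuda` extension, and loose FP8-level tolerances like those used above):

```python
import torch
from flash_attn_interface import _flash_attn_forward

torch.manual_seed(0)
q = torch.randn(1, 256, 4, 128, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
unit = torch.ones(1, dtype=torch.float32, device="cuda")
out, *_ = _flash_attn_forward(
    q.to(torch.float8_e4m3fn), k.to(torch.float8_e4m3fn),
    v.to(torch.float8_e4m3fn), 128 ** -0.5, causal=False,
    descale_q=unit, descale_k=unit, descale_v=unit,
)
# Reference on the same (dequantized-by-ones) inputs; SDPA expects
# (batch, heads, seqlen, dim), hence the transposes.
ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
assert (out.float() - ref.float()).abs().max().item() < 0.5  # FP8 tolerance
```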