Browse Source

Fp8 kernel with "in-kernel" transpose of V in producer (#1100)

* base version

* restructure pipelines, add special fp8 epilogue

* add variants

* add fp8 causal and modify dynamic tile scheduler

* better causal schedule

* maintain two schedules for non causal and causal

* removing macros

* fix regression

* clean up unneeded methods and variants

* fix mistake with NumProducerThreads

* base version

* restructure pipelines, add special fp8 epilogue

* add variants

* add fp8 causal and modify dynamic tile scheduler

* better causal schedule

* maintain two schedules for non causal and causal

* removing macros

* fix regression

* clean up unneeded methods and variants

* fix mistake with NumProducerThreads

* use seqlen traits

* add fp8 .cu files and benchmark script

* fix merge issue

* fix merge issue

* fix merge issue

* remove duplicate code

* fix regression with varseqlen

* move varseqlen init in constexpr

* fix test script

* more constexpr on varseqlen and add max offset

* add back test cases
jayhshah 7 months ago
parent
commit
5018ac6ac5

+ 333 - 0
hopper/benchmark_flash_attention_fp8.py

@@ -0,0 +1,333 @@
+# Install the newest triton version with
+# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
+import pickle
+import math
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
+from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
+
+from flash_attn import flash_attn_qkvpacked_func
+from flash_attn_interface import flash_attn_func
+
+try:
+    from triton_fused_attention import attention as attention_triton
+except ImportError:
+    attention_triton = None
+
+try:
+    import xformers.ops as xops
+except ImportError:
+    xops = None
+
+try:
+    import cudnn
+except ImportError:
+    cudnn = None
+
+
+def convert_to_cudnn_type(torch_type):
+    if torch_type == torch.float16:
+        return cudnn.data_type.HALF
+    elif torch_type == torch.bfloat16:
+        return cudnn.data_type.BFLOAT16
+    elif torch_type == torch.float32:
+        return cudnn.data_type.FLOAT
+    elif torch_type == torch.int32:
+        return cudnn.data_type.INT32
+    elif torch_type == torch.int64:
+        return cudnn.data_type.INT64
+    elif torch_type == torch.float8_e4m3fn:
+        return cudnn.data_type.FP8_E4M3
+    elif torch_type == torch.float8_e4m3fn:
+        return cudnn.data_type.FP8_E5M2
+    else:
+        raise ValueError("Unsupported tensor data type.")
+
+def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
+    b, _, _, nheads, headdim = qkv.shape
+    assert cudnn is not None, 'CUDNN is not available'
+    o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)
+    o_gpu_transposed = torch.as_strided(
+        o_gpu,
+        [b, nheads, seqlen_q, headdim],
+        [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
+    )
+    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)
+    amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
+    amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
+    graph = cudnn.pygraph(
+        io_data_type=convert_to_cudnn_type(qkv.dtype),
+        intermediate_data_type=cudnn.data_type.FLOAT,
+        compute_data_type=cudnn.data_type.FLOAT,
+    )
+    new_q = torch.as_strided(
+        qkv,
+        [b, nheads, seqlen_q, headdim],
+        [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
+        storage_offset=0,
+    )
+    q = graph.tensor(
+        name = "Q",
+        dim = list(new_q.shape),
+        stride = list(new_q.stride()),
+        data_type=convert_to_cudnn_type(qkv.dtype)
+    )
+    new_k = torch.as_strided(
+        qkv,
+        [b, nheads, seqlen_k, headdim],
+        [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
+        storage_offset=nheads * headdim,
+    )
+    k = graph.tensor(
+        name = "K",
+        dim = list(new_k.shape),
+        stride = list(new_k.stride()),
+        data_type=convert_to_cudnn_type(qkv.dtype)
+    )
+    new_v = torch.as_strided(
+        qkv,
+        [b, nheads, seqlen_k, headdim],
+        [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
+        storage_offset=nheads * headdim * 2,
+    )
+    v = graph.tensor(
+        name = "V",
+        dim = list(new_v.shape),
+        stride = list(new_v.stride()),
+        data_type=convert_to_cudnn_type(qkv.dtype)
+    )
+
+    def get_default_scale_tensor():
+        return graph.tensor(
+            dim = [1, 1, 1, 1],
+            stride = [1, 1, 1, 1],
+            data_type=cudnn.data_type.FLOAT
+        )
+
+    default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
+    descale_q = get_default_scale_tensor()
+    descale_k = get_default_scale_tensor()
+    descale_v = get_default_scale_tensor()
+    descale_s = get_default_scale_tensor()
+    scale_s = get_default_scale_tensor()
+    scale_o = get_default_scale_tensor()
+
+    o, _, amax_s, amax_o = graph.sdpa_fp8(
+        q=q,
+        k=k,
+        v=v,
+        descale_q=descale_q,
+        descale_k=descale_k,
+        descale_v=descale_v,
+        descale_s=descale_s,
+        scale_s=scale_s,
+        scale_o=scale_o,
+        is_inference=True,
+        attn_scale=1.0 / math.sqrt(headdim),
+        use_causal_mask=causal,
+        name="sdpa",
+    )
+
+    o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())
+
+    amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
+    amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
+    # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
+
+    graph.validate()
+    graph.build_operation_graph()
+    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
+    graph.check_support()
+    graph.build_plans()
+
+    variant_pack = {
+        q: new_q,
+        k: new_k,
+        v: new_v,
+        descale_q: default_scale_gpu,
+        descale_k: default_scale_gpu,
+        descale_v: default_scale_gpu,
+        descale_s: default_scale_gpu,
+        scale_s: default_scale_gpu,
+        scale_o: default_scale_gpu,
+        o: o_gpu_transposed,
+        amax_s: amax_s_gpu,
+        amax_o: amax_o_gpu,
+    }
+
+    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
+
+    def run(*args, **kwargs):
+        graph.execute(variant_pack, workspace)
+        return o_gpu, amax_o_gpu
+
+    return run
+
+
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, head_dim)
+        dropout_p: float
+    Output:
+        output: (batch_size, seqlen, nheads, head_dim)
+    """
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
+    if causal:
+        # "triu_tril_cuda_template" not implemented for 'BFloat16'
+        # So we have to construct the mask in float
+        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
+    attention = torch.softmax(scores, dim=-1)
+    attention_drop = F.dropout(attention, dropout_p)
+    output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
+    return output.to(dtype=qkv.dtype)
+
+def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+def efficiency(flop, time):
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+def time_fwd(func, *args, **kwargs):
+    time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
+    time_f = benchmark_forward(func, *args, **kwargs)
+    return time_f[1].mean
+
+
+torch.manual_seed(0)
+
+repeats = 30
+device = 'cuda'
+# dtype = torch.float16
+dtype = torch.float8_e4m3fn
+
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
+# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
+# bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2), (4, 4224), (2, 8448), (1, 8448 * 2)]
+# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
+causal_vals = [False, True]
+headdim_vals = [128]
+dim = 2048
+# dim = 256
+dropout_p = 0.0
+
+methods = (["Pytorch", "Flash3", "cuDNN"]        
+        # + (["Triton"] if attention_triton is not None else [])
+        #    + (["xformers.c"] if xops is not None else [])
+        #    + (["xformers.f"] if xops is not None else [])
+           )
+
+time_f = {}
+time_b = {}
+time_f_b = {}
+speed_f = {}
+speed_b = {}
+speed_f_b = {}
+for causal in causal_vals:
+    for headdim in headdim_vals:
+        for batch_size, seqlen in bs_seqlen_vals:
+            torch.cuda.empty_cache()
+            config = (causal, headdim, batch_size, seqlen)
+            nheads = dim // headdim
+            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=False) for _ in range(3)]
+            
+            qkv = torch.stack([q, k, v], dim=2)
+            qkv = qkv.to(torch.float16)
+            f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
+            time_f[config, "Pytorch"] = f
+            res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
+
+            if attention_triton is not None:
+                q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
+                k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
+                v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
+                scale = 1 / math.sqrt(headdim)
+                f = time_fwd(
+                    attention_triton, q_transposed, k_transposed, v_transposed,
+                    causal, scale, repeats=5, verbose=False, desc='Triton'
+                )
+                f = time_fwd(
+                    attention_triton, q_transposed, k_transposed, v_transposed,
+                    causal, scale, repeats=repeats, verbose=False, desc='Triton'
+                )
+                time_f[config, "Triton"] = f
+                res = attention_triton(
+                    q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2),
+                    causal, scale
+                ).half().transpose(1, 2)
+                torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)
+
+            # out = torch.empty_like(q)
+            q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)                        
+            f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
+
+            # res = flash_attn_func(q, k, v, causal=causal)
+            # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)
+
+            time_f[config, "Flash3"] = f
+
+            if cudnn is not None:
+                qkv_fp8 = qkv.to(dtype)
+                time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
+                f = time_fwd(
+                    cudnn_spda_setup(
+                        qkv_fp8, seqlen, seqlen,
+                        causal=causal
+                    ),
+                    repeats=repeats, verbose=False
+                )
+                time_f[config, "cuDNN"] = f
+                # res, amax_o = cudnn_spda_setup(
+                #     qkv_fp8, seqlen, seqlen,
+                #     causal=causal
+                # )()
+                # res = res.half()
+                # TODO: CUDNN has numerics issues when
+                # num_heads=16, dim=128, seq_len=1024, batch_size=2
+                # or larger sizes.
+                # res_cpu = res.cpu().reshape(-1)
+                # res_baseline_cpu = res_baseline.cpu().reshape(-1)
+                # print(amax_o)
+                # print(res)
+                # print(res_baseline)
+                # for i in range(len(res_cpu)):
+                #     item = res_cpu[i]
+                #     item_baseline = res_baseline_cpu[i]
+                #     if abs(item - item_baseline) > 0.5:
+                #         print(i)
+                #         print(item)
+                #         print(item_baseline)
+                # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)
+
+            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+            for method in methods:
+                speed_f[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                    time_f[config, method]
+                )
+                #print (time_f[config,method])
+                print(
+                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
+                )
+
+
+# with open('flash3_attn_time.plk', 'wb') as fp:
+#     pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)

+ 80 - 3
hopper/epilogue_fwd_sm90_tma.hpp

@@ -20,7 +20,7 @@ using namespace cute;
 template <typename Ktraits, typename Seqlen_traits>
 struct CollectiveEpilogueFwd {
 
-    using Element = typename Ktraits::Element;
+    using Element = typename Ktraits::OutputType;    
     static constexpr int kBlockM = Ktraits::kBlockM;
     static constexpr int kBlockN = Ktraits::kBlockN;
     static constexpr int kHeadDim = Ktraits::kHeadDim;
@@ -28,7 +28,7 @@ struct CollectiveEpilogueFwd {
 
     static constexpr int kNWarps = Ktraits::kNWarps;
     static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
-    static constexpr bool Is_WS = kNWarps >= 12;
+    static constexpr bool Is_WS = kNWarps >= 12;    
 
     static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
     static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;
@@ -71,6 +71,16 @@ struct CollectiveEpilogueFwd {
         TiledCopyOValLayout{} // Val layout
     ));
 
+    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
+    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
+                                 Stride<_4, _32, _1, _0>>;
+    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
+                                Stride<_0, _2, Stride<_4, _1>, _8>>;
+    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, Element>{},
+                      ThreadLayoutrO{}, ValueLayoutrO{}));
+    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
+    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
+
     // Host side kernel arguments
     struct Arguments {
         Element* ptr_O;
@@ -150,7 +160,7 @@ struct CollectiveEpilogueFwd {
         if (get<1>(taccOcO_row(_0{})) == 0) {
             #pragma unroll
             for (int mi = 0; mi < size(lse); ++mi) {
-                const int row = get<0>(taccOcO_row(mi));
+                const int row = get<0>(taccOcO_row(mi));                
                 if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
             }
         }
@@ -170,6 +180,73 @@ struct CollectiveEpilogueFwd {
         );
     }
 
+    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
+    CUTLASS_DEVICE void
+    store_fp8(Params const& epilogue_params,
+          FrgTensorO const& tOrO,
+          FrgTensorLSE const& lse,
+          SharedStorage& shared_storage,
+          TiledMma tiled_mma,
+          int thread_idx,
+          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
+          const Seqlen_traits& seqlen_traits_q
+          ) {
+        // using SmemLayoutrO = typename Ktraits::SmemLayoutrO;
+        // using TiledCopyrO = typename Ktraits::TiledCopyrO;
+        auto [m_block, bidh, bidb] = block_coord;        
+
+        TiledCopyrO rmem_tiled_copy_O;
+        Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{});
+        auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx);
+        
+        Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc);
+        Tensor tOrO_out = flash::convert_type<Element>(tOrO); // Element is Ktraits::OutputType
+        Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO));
+
+        // Make sure all WGs have finished reading V
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);        
+        cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO);
+        cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
+                                            cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
+        
+        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
+        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
+            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
+        Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
+        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
+        Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
+        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
+        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
+        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices.
+        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
+        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M        
+        int const seqlen_q = [&] {
+            if constexpr(Seqlen_traits::kUseVarSeqLen) { return seqlen_traits_q.actual_seq_len; }
+            else { return shape<2>(epilogue_params.layout_LSE); }
+        }();        
+        if (get<1>(taccOcO_row(_0{})) == 0) {
+            #pragma unroll
+            for (int mi = 0; mi < size(lse); ++mi) {
+                const int row = get<0>(taccOcO_row(mi));
+                if (row < seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
+            }
+        }
+        
+        int write_warp_idx = kNWarps - 1;
+        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
+            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
+                                              cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
+        }
+        TiledCopyO gmem_tiled_copy_O;
+        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
+        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
+            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, 
+            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO, 
+            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
+        );
+    }
+
     CUTLASS_DEVICE void
     store_tail() {
         tma_store_wait<0>();

+ 28 - 14
hopper/flash_api.cpp

@@ -249,7 +249,13 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
             }
         }
     } else {
-        // run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
+        if (params.d == 64) {
+            run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
+        } else if (params.d == 128) {
+            run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
+        } else if (params.d == 256) {
+            run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
+        }        
     }
 }
 
@@ -266,12 +272,12 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
-                "FlashAttention only support fp16 and bf16 data type for now");
+    // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+    //             "FlashAttention only support fp16 and bf16 data type for now");
     // TODO: will add e4m3 later
     // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
-                // "FlashAttention only support fp16 and bf16 data type");
-                // "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
+    //             "FlashAttention only support fp16 and bf16 data type");
+    //             "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
 
@@ -317,13 +323,21 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     at::Tensor out;
     if (out_.has_value()) {
         out = out_.value();
-        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
+                    ? (out.dtype() == at::kHalf)
+                    : (out.dtype() == q_dtype),
+                "Output must have the same dtype as input dtype if dtype is "
+                "not fp8, or fp16 for fp8 input.");
         CHECK_DEVICE(out);
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
-        out = torch::empty_like(q_padded);
+        if (q_dtype == at::ScalarType::Float8_e4m3fn)
+            out = torch::empty_like(q_padded, at::kHalf);
+        else
+            out = torch::empty_like(q_padded);
     }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -534,13 +548,13 @@ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     //         run_mha_bwd_<elem_type, kHeadDim>(params, stream);
     //     });
     // });
-    if (params.d == 64) {
-      run_mha_bwd_<cutlass::half_t, 64>(params, stream);
-    } else if (params.d == 128) {
-      run_mha_bwd_<cutlass::half_t, 128>(params, stream);
-    } else {
-      run_mha_bwd_<cutlass::half_t, 256>(params, stream);
-    }
+    // if (params.d == 64) {
+    //   run_mha_bwd_<cutlass::half_t, 64>(params, stream);
+    // } else if (params.d == 128) {
+    //   run_mha_bwd_<cutlass::half_t, 128>(params, stream);
+    // } else {
+    //   run_mha_bwd_<cutlass::half_t, 256>(params, stream);
+    // }
 }
 
 std::vector<at::Tensor>

+ 9 - 0
hopper/flash_fwd_hdim128_e4m3_sm90.cu

@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::float_e4m3_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128_fp8<cutlass::float_e4m3_t>(params, stream);
+}

+ 9 - 0
hopper/flash_fwd_hdim256_e4m3_sm90.cu

@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::float_e4m3_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256_fp8<cutlass::float_e4m3_t>(params, stream);
+}

+ 9 - 0
hopper/flash_fwd_hdim64_e4m3_sm90.cu

@@ -0,0 +1,9 @@
+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::float_e4m3_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64_fp8<cutlass::float_e4m3_t>(params, stream);
+}

+ 194 - 0
hopper/flash_fwd_kernel.h

@@ -188,4 +188,198 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
 
 }
 
+template <typename Ktraits, bool Is_causal, typename TileScheduler, typename Seqlen_traits>
+__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
+    compute_attn_ws_fp8(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>::Params const mainloop_params,
+                        CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits>::Params const epilogue_params,
+                        CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
+                        Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k
+                        ) {
+
+    using Element = typename Ktraits::Element;
+    static_assert(cutlass::sizeof_bits_v<Element> == 8);
+    using ElementAccum = typename Ktraits::ElementAccum;
+    using SoftType = ElementAccum;
+    using TileShape_MNK = typename Ktraits::TileShape_MNK;
+    using ClusterShape = typename Ktraits::ClusterShape_MNK;
+
+    static_assert(Ktraits::Is_WS);
+    static constexpr bool Is_WS = Ktraits::Is_WS;
+    static constexpr bool kUseVarSeqLen = Seqlen_traits::kUseVarSeqLen;
+
+    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
+    static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
+    static constexpr int kBlockM = Ktraits::kBlockM;
+    // static constexpr int kBlockN = Ktraits::kBlockN;
+    // static constexpr int kHeadDim = Ktraits::kHeadDim;
+    static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128;  
+    // for now, disable for hdim 128 causal to avoid perf regression with register spilling
+    static constexpr bool Use_max_offset = !(Is_causal && Ktraits::kHeadDim == 128);    
+
+    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>;
+    using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits>;
+
+    using MainloopPipeline = typename Ktraits::MainloopPipeline;
+    using MainloopPipelineVt = typename Ktraits::MainloopPipelineNoTMA;
+    using PipelineParams = typename MainloopPipeline::Params;
+    using PipelineParamsVt = typename MainloopPipelineVt::Params;
+    using PipelineState = typename MainloopPipeline::PipelineState;
+
+    extern __shared__ char shared_memory[];
+    auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
+
+    int const lane_predicate = cute::elect_one_sync();
+    int const warp_idx = cutlass::canonical_warp_idx_sync();
+
+    // Issue Tma Descriptor Prefetch from a single thread
+    if (warp_idx == 0 && lane_predicate) {
+        CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
+        CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
+    }
+
+    // Obtain warp index
+    int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
+
+    // additional pipeline to synchronize out-of-place smem transpose of V
+    PipelineParamsVt pipeline_params_vt;
+    pipeline_params_vt.producer_arv_count = NumCopyThreads;
+    pipeline_params_vt.consumer_arv_count = NumMmaThreads;
+    MainloopPipelineVt pipeline_vt(shared_storage.pipeline_vt, pipeline_params_vt);
+    
+    PipelineParams pipeline_params;
+    pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
+    int warp_group_idx = cutlass::canonical_warp_group_idx();
+    pipeline_params.role = warp_group_idx == 0
+        ? MainloopPipeline::ThreadCategory::Producer
+        : MainloopPipeline::ThreadCategory::Consumer;
+    pipeline_params.is_leader = warp_group_thread_idx == 0;
+    pipeline_params.num_consumers = NumMmaThreads;
+
+    if (warp_idx == 0 && lane_predicate) {
+        shared_storage.barrier_Q.init(1 /*numThreads*/);
+        shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
+    }
+    // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
+    MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
+    // pipeline_v has producer warpgroup for its consumer in fp8 kernel
+    pipeline_params.num_consumers = NumCopyThreads;
+    pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer;
+    MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
+
+    CollectiveMainloop collective_mainloop;
+    CollectiveEpilogue collective_epilogue;
+
+    // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
+    if constexpr (size(ClusterShape{}) > 1) {
+        cute::cluster_arrive_relaxed();
+        cute::cluster_wait();
+    } else {
+        __syncthreads();
+    }
+
+    static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16);
+    if (warp_group_idx == 0) {  // Producer
+        cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();
+        
+        PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>(); 
+        PipelineState smem_pipe_read, smem_pipe_release;
+
+        int work_idx = 0;
+
+        TileScheduler scheduler(&shared_storage.tile_count_semaphore);
+        for (auto work_tile_info = scheduler.get_initial_work();
+                work_tile_info.is_valid(scheduler_params);
+                work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
+            auto block_coord = work_tile_info.get_block_coord(scheduler_params);
+            auto [m_block, bidh, bidb] = block_coord;
+
+            if constexpr(kUseVarSeqLen) {
+                seqlen_traits_q.init(bidb);
+                seqlen_traits_k.init(bidb);
+                if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
+                    continue;
+                }
+            }
+            int n_block_max = collective_mainloop.get_n_block_max(
+                mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+            if constexpr(Is_causal) {
+                if(n_block_max <= 0) {
+                    scheduler.prefetch_next_work(scheduler_params, work_tile_info);
+                    scheduler.broadcast_next_work(work_tile_info);
+                    // need to sync producer warpgroup
+                    cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
+                    continue;
+                }
+            }
+            collective_mainloop.load_fp8(
+                mainloop_params, pipeline_k, pipeline_v, pipeline_vt,
+                smem_pipe_write, smem_pipe_read, shared_storage,
+                scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
+                seqlen_traits_q, seqlen_traits_k);
+            ++work_idx;
+            // don't need to sync producer warpgroup here
+            // if constexpr (Is_causal) {
+            //     cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/); }
+        }
+        collective_mainloop.load_tail_one_write(pipeline_k, pipeline_v, smem_pipe_write);
+    } else {  // Consumer
+        cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();        
+
+        TileScheduler scheduler(&shared_storage.tile_count_semaphore);
+        // Initialize matmul objects.
+        typename Ktraits::TiledMma1 tiled_mma1;
+        PipelineState smem_pipe_read;
+        PipelineState smem_pipe_release;
+
+        collective_mainloop.mma_init();
+        scheduler.init_consumer();
+
+        int work_idx = 0;
+        CUTLASS_PRAGMA_NO_UNROLL
+        for (auto work_tile_info = scheduler.get_initial_work();
+             work_tile_info.is_valid(scheduler_params);
+             work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
+            // Attention output (GEMM-II) accumulator.
+            Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
+            flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), Use_max_offset> softmax;
+
+            auto block_coord = work_tile_info.get_block_coord(scheduler_params);
+            auto [m_block, bidh, bidb] = block_coord;
+
+            if constexpr(kUseVarSeqLen) {
+                seqlen_traits_q.init(bidb);
+                seqlen_traits_k.init(bidb);
+                if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
+                    continue;
+                }
+            }
+            int n_block_max = collective_mainloop.get_n_block_max(
+                mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+            if constexpr(Is_causal) {
+                if(n_block_max <= 0) {  // We exit early and write 0 to gO and -inf to gLSE.
+                    collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
+                    continue;
+                }
+            }
+            
+            collective_mainloop.mma_fp8<Delay_V_release>(
+                mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release,
+                tOrO, softmax, n_block_max,
+                threadIdx.x - NumCopyThreads, work_idx, m_block,
+                shared_storage, seqlen_traits_q, seqlen_traits_k); 
+
+        #ifndef NO_FP8_COLUMN_PERMUTE
+            collective_epilogue.store_fp8(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
+                                      threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
+        #else
+            collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
+                                      threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
+        #endif
+            ++work_idx;
+        }
+        collective_epilogue.store_tail();
+    }
+
+}
+
 } // namespace flash

+ 66 - 4
hopper/flash_fwd_launch_template.h

@@ -21,6 +21,7 @@
 template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
 void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     using Element = typename Kernel_traits::Element;
+    using OutputType = typename Kernel_traits::OutputType;
     using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
     using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
 
@@ -32,7 +33,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         flash::SingleTileScheduler,
         std::conditional_t<!Is_causal,
             flash::StaticPersistentTileScheduler,
-            flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>
+            flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup, Kernel_traits::NumProducerThreads>
     >>;
     // using Scheduler = flash::SingleTileScheduler;
     Seqlen_traits seqlen_traits_q(
@@ -60,7 +61,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         });
     typename CollectiveEpilogue::Params epilogue_params =
         CollectiveEpilogue::to_underlying_arguments({
-            static_cast<Element*>(params.o_ptr),
+            static_cast<OutputType*>(params.o_ptr),
             seqlen_traits_q.get_gmem_layout(
                 params.seqlen_q, params.d, params.h, params.b,
                 params.o_row_stride, params.o_head_stride, params.o_batch_stride
@@ -78,12 +79,16 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
     // Get the ptr to kernel function.
     void *kernel;
-    kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
+    if constexpr(cutlass::sizeof_bits_v<Element> == 8)
+        kernel = (void *)flash::compute_attn_ws_fp8<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
+    else
+        kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
     int smem_size = sizeof(typename Kernel_traits::SharedStorage);
     // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
     // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
     // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
-    // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
+    // int smem_size_o = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_o));
+    // printf("smem_size = %d, q = %d, k = %d, v = %d, o = %d.\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_o);
     if (smem_size >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
@@ -151,3 +156,60 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
         });
     });
 }
+
+template<typename T>
+void run_mha_fwd_hdim64_fp8(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 64;
+    constexpr static int kBlockM = 192;
+    constexpr static int kBlockN = 128;
+    constexpr static int kNWarps = 4 + kBlockM/16;
+    constexpr static int kStages = 4;    
+    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            // Only use Cluster if number of tiles along seqlen_q is even
+            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
+                        !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+                run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                              false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);            
+            });
+        });
+    });    
+}
+
+template<typename T>
+void run_mha_fwd_hdim128_fp8(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 128;
+    constexpr static int kBlockM = 128;
+    constexpr static int kBlockN = 256;
+    constexpr static int kNWarps = 4 + kBlockM/16;
+    constexpr static int kStages = 2;    
+    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            // Only use Cluster if number of tiles along seqlen_q is even
+            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
+                        !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+                run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                              false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
+            });
+        });
+    });    
+}
+
+template<typename T>
+void run_mha_fwd_hdim256_fp8(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 256; 
+    constexpr static int kBlockM = 128;
+    constexpr static int kBlockN = 128;
+    constexpr static int kNWarps = 4 + kBlockM/16;
+    constexpr static int kStages = 2;    
+    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            // Only use Cluster if number of tiles along seqlen_q is even
+            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, kBlockM) % 2 == 0 && !Is_causal &&
+                        !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+                run_flash_fwd<Flash_fwd_kernel_traits_fp8<Headdim, kBlockM, kBlockN, kNWarps, kStages,
+                              false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
+            });
+        });
+    });    
+}

+ 144 - 2
hopper/kernel_traits.h

@@ -33,17 +33,41 @@ struct SharedStorageQKVO {
     };
 };
 
+template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
+          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
+struct SharedStorageQKVOVt {
+  struct {
+    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
+    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
+    cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;  
+    union {
+        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
+        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
+    };
+  };
+  struct {    
+    cutlass::arch::ClusterTransactionBarrier barrier_Q;
+    cutlass::arch::ClusterBarrier barrier_O;
+    typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
+    typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
+    typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
+    int tile_count_semaphore;
+  };
+};
+
 // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
 template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
          int kClusterM_ = 1, typename elem_type=cutlass::half_t>
 struct Flash_fwd_kernel_traits {
     using Element = elem_type;
     using ElementAccum = float;
+    using OutputType = elem_type;
     using index_t = int64_t;
 
     // The number of threads.
     static constexpr int kNWarps = kNWarps_;
     static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
+    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp;
 
     static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
     static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
@@ -88,9 +112,16 @@ struct Flash_fwd_kernel_traits {
         decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
     using SmemLayoutV =
         decltype(tile_to_shape(SmemLayoutAtomV{},
-                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
+                 make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int<kStages>{})));
+
+    // Note this is the transpose in terms of the view, not in terms of memory.
+    using SmemLayoutVt =
+        decltype(composition(SmemLayoutV{},
+                    make_ordered_layout(
+                        make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
+                        Step<_2, _1, _3>{})));
 
-    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
+    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
         decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
     using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
 
@@ -100,11 +131,122 @@ struct Flash_fwd_kernel_traits {
                                             SmemLayoutK, SmemLayoutV, SmemLayoutO>;
 
     using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
+    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
     using PipelineState = typename cutlass::PipelineState<kStages>;
     // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
 
 };
 
+// Traits struct for fp8 kernel with in-kernel transpose
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
+         int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t>
+struct Flash_fwd_kernel_traits_fp8 {
+    using Element = elem_type;
+    static_assert(cutlass::sizeof_bits_v<Element> == 8);
+    using ElementAccum = float;
+    using OutputType = cutlass::half_t;
+    using index_t = int64_t;      
+
+    // The number of threads.
+    static constexpr int kNWarps = kNWarps_;
+    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
+    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
+
+    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
+    static_assert(kNWarps_ == 12 || kNWarps_ == 16);
+    static constexpr bool Is_WS = true;    
+    static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers");    
+
+    static constexpr int kBlockM = kBlockM_;
+    static constexpr int kBlockN = kBlockN_;
+    static constexpr int kHeadDim = kHeadDim_;
+    static_assert(kHeadDim % 32 == 0);
+    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
+
+    static constexpr int kClusterM = kClusterM_;
+    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
+
+    static constexpr int kStages = kStages_;
+    static_assert(kStages > 1);
+
+    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;    
+    using TiledMma0 = decltype(cute::make_tiled_mma(
+        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
+        AtomLayoutMNK{}));
+    
+    using TiledMma1 = decltype(cute::make_tiled_mma(
+        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{}))>(),
+        AtomLayoutMNK{}));
+
+    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
+        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
+    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
+
+    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
+        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
+    using SmemLayoutK =
+        decltype(tile_to_shape(SmemLayoutAtomK{},
+                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
+
+    using TransposeShapeAtomV = Shape<_64, _64>;    
+    using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+    using SmemLayoutV =
+        decltype(tile_to_shape(SmemLayoutAtomV{},
+                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
+    
+    // for fp8 in-kernel transpose -- src layout
+    using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
+    using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
+    using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
+        shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
+    using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
+
+    // For fp8, this is the memory transpose.
+    using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+    using SmemLayoutVt =
+        decltype(tile_to_shape(SmemLayoutAtomVt{},
+                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
+
+    // for fp8 in-kernel transpose -- dst layout
+    using SmemLayoutVtTrans =
+        decltype(composition(SmemLayoutVt{},
+                             make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
+    using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
+#ifndef NO_FP8_COLUMN_PERMUTE
+    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
+#else
+    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
+#endif
+    using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
+        shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
+    using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
+
+    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
+        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
+    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
+
+    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
+    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
+                                 Stride<_4, _32, _1, _0>>;
+    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
+                                Stride<_0, _2, Stride<_4, _1>, _8>>;
+    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, OutputType>{},
+                      ThreadLayoutrO{}, ValueLayoutrO{}));
+
+    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
+    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
+
+    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
+
+    using SharedStorage = SharedStorageQKVOVt<kStages, Element, Element, OutputType, SmemLayoutQ,
+                          SmemLayoutK, SmemLayoutV, SmemLayoutO>;
+
+    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
+    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
+    using PipelineState = typename cutlass::PipelineState<kStages>;
+    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
+};
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,

+ 567 - 42
hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

@@ -21,6 +21,64 @@ namespace flash {
 
 using namespace cute;
 
+// 4 warps
+struct SmemTransposeFp8_64x64 {
+
+  using Element = cutlass::float_e4m3_t;
+  
+  using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
+  using ldsm_value_shape = Shape<_2, _8, _2, _1>;  
+  using ldsm_value_stride = Stride<_2, _4, _1, _0>;
+  using TiledCopyLDSM = decltype(make_tiled_copy(
+      Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
+      Layout<ldsm_value_shape, ldsm_value_stride>{}));
+  TiledCopyLDSM tiled_copy_ldsm;  
+
+  using stsm_thread_shape = Shape<_4, _1, _8, _4>;
+  // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
+#ifndef NO_FP8_COLUMN_PERMUTE
+  using stsm_value_shape = Shape<_4, _4, _1, _2>;
+  using stsm_value_stride = Stride<_1, _8, _0, _4>;
+#else
+  using stsm_value_shape = Shape<_4, _4, _2, _1>;
+  using stsm_value_stride = Stride<_1, _8, _4, _0>;
+#endif
+
+  using TiledCopySTSM =
+      decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{},
+                               Layout<stsm_thread_shape>{},
+                               Layout<stsm_value_shape, stsm_value_stride>{}));
+  TiledCopySTSM tiled_copy_stsm;
+
+  template <class SmemTensor, class SmemTensorOut>
+  CUTLASS_DEVICE void operator()(SmemTensor &&s_in, SmemTensorOut &&s_out) {
+    using namespace cute;
+
+    auto tid = threadIdx.x;
+    auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
+    auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
+
+    auto tXsX = thr_copy_ldsm.partition_S(s_in);
+    auto tXrX = make_tensor<Element>(shape(tXsX));    
+    auto tXsX_out = thr_copy_stsm.partition_D(s_out);
+
+    cute::copy(tiled_copy_ldsm, tXsX, tXrX);
+
+    auto data = tXrX.data();
+    // size(tXrX) == 32
+    CUTLASS_PRAGMA_UNROLL
+    for (int n = 0; n < size(tXrX); n += 8) {
+      uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
+      auto upper = data_32bit[0];
+      auto lower = data_32bit[1];
+      data_32bit[0] = __byte_perm(upper, lower, 0x6420);
+      data_32bit[1] = __byte_perm(upper, lower, 0x7531);
+    }
+
+    cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
+  }
+};
+
 template <typename Ktraits, bool Is_causal, typename Seqlen_traits>
 struct CollectiveMainloopFwd {
 
@@ -29,40 +87,15 @@ struct CollectiveMainloopFwd {
     using ClusterShape = typename Ktraits::ClusterShape_MNK;
 
     static constexpr int kStages = Ktraits::kStages;
-    static constexpr int kHeadDim = Ktraits::kHeadDim;
+    static constexpr int kHeadDim = Ktraits::kHeadDim;    
 
     using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
     using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
-
-    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
-        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
-    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
-
-    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
-        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
-    using SmemLayoutK =
-        decltype(tile_to_shape(SmemLayoutAtomK{},
-                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
-    using SmemLayoutV = SmemLayoutK;
-    // Note this is the transpose in terms of the view, not in terms of memory.
-    using SmemLayoutVt =
-        decltype(cute::composition(SmemLayoutV{},
-                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
-                                               make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
-    // using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom<Element>;
-    // using SmemLayoutVt =
-    //     decltype(tile_to_shape(SmemLayoutAtomVt{},
-    //                            make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
-    //                            Step<_2, _1, _3>{}));  // This gives correct results, without Step it's wrong
-    // using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::MN, Element,
-    //     decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
-    // using SmemLayoutVt =
-    //     decltype(tile_to_shape(SmemLayoutAtomVt{},
-    //              make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
-    // using SmemLayoutAtomVTMA = cute::GMMA::Layout_K_SW128_Atom<Element>;
-    // using SmemLayoutVTMA =
-    //     decltype(tile_to_shape(SmemLayoutAtomVTMA{},
-    //                            make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
+    
+    using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
+    using SmemLayoutK = typename Ktraits::SmemLayoutK;
+    using SmemLayoutV = typename Ktraits::SmemLayoutV;
+    using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
 
     using TMA_Q = decltype(make_tma_copy(
         GmemTiledCopyQ{},
@@ -75,7 +108,7 @@ struct CollectiveMainloopFwd {
         select<0, 2>(TileShape_MNK{}),
         _1{}));  // no mcast for Q
 
-    using TMA_KV = decltype(make_tma_copy(
+    using TMA_K = decltype(make_tma_copy(
         GmemTiledCopyKV{},
         make_tensor(
             make_gmem_ptr(static_cast<Element const*>(nullptr)), 
@@ -86,8 +119,21 @@ struct CollectiveMainloopFwd {
         select<1, 2>(TileShape_MNK{}),
         size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
 
+    // TMA_V may differ from TMA_K for fp8 kernel (e.g. swizzling mode)
+    using TMA_V = decltype(make_tma_copy(
+        GmemTiledCopyKV{},
+        make_tensor(
+            make_gmem_ptr(static_cast<Element const*>(nullptr)),
+            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)),
+            typename Seqlen_traits::StrideT{}
+        ),
+        take<0, 2>(SmemLayoutV{}),
+        select<1, 2>(TileShape_MNK{}),
+        size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
+
     static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
     using MainloopPipeline = typename Ktraits::MainloopPipeline;
+    using MainloopPipelineNoTMA = typename Ktraits::MainloopPipelineNoTMA;
     using PipelineParams = typename MainloopPipeline::Params;
     using PipelineState = typename MainloopPipeline::PipelineState;
 
@@ -95,7 +141,10 @@ struct CollectiveMainloopFwd {
     static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
     static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
 
-    static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
+    // static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
+    static constexpr bool UseSchedulerBarrier =
+        cutlass::sizeof_bits_v<Element> == 8 ? kHeadDim >= 128
+                                             : kHeadDim <= 128;    
 
     // Host side kernel arguments
     struct Arguments {
@@ -114,8 +163,9 @@ struct CollectiveMainloopFwd {
         typename Seqlen_traits::LayoutT layout_K;
         typename Seqlen_traits::LayoutT layout_V;
         cutlass::FastDivmod qhead_per_khead_divmod;
-        TMA_Q tma_load_Q;
-        TMA_KV tma_load_K, tma_load_V;
+        TMA_Q tma_load_Q;        
+        TMA_K tma_load_K;
+        TMA_V tma_load_V;
         float const softmax_scale_log2;
     };
 
@@ -130,14 +180,14 @@ struct CollectiveMainloopFwd {
             select<0, 2>(TileShape_MNK{}),
             _1{}); // no mcast for Q
         Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
-        TMA_KV tma_load_K = make_tma_copy(
+        TMA_K tma_load_K = make_tma_copy(
             GmemTiledCopyKV{},
             mK,
             SmemLayoutK{}(_, _, _0{}),
             select<1, 2>(TileShape_MNK{}),
             size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
         Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
-        TMA_KV tma_load_V = make_tma_copy(
+        TMA_V tma_load_V = make_tma_copy(
             GmemTiledCopyKV{},
             mV,
             SmemLayoutV{}(_, _, _0{}),
@@ -164,9 +214,9 @@ struct CollectiveMainloopFwd {
           const Seqlen_traits& seqlen_traits_k
         ) {
         static constexpr int kBlockM = get<0>(TileShape_MNK{});
-        static constexpr int kBlockN = get<1>(TileShape_MNK{});
-        int const seqlen_q = seqlen_traits_q.actual_seq_len;
-        int const seqlen_k = seqlen_traits_k.actual_seq_len;
+        static constexpr int kBlockN = get<1>(TileShape_MNK{});        
+        int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
+        int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);        
         int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
         if constexpr (Is_causal) {
             n_block_max = std::min(n_block_max,
@@ -279,13 +329,242 @@ struct CollectiveMainloopFwd {
         scheduler.broadcast_next_work(work_tile_info);
     }
 
+    template <typename Scheduler, typename SharedStorage>
+    CUTLASS_DEVICE void
+    load_fp8(Params const& mainloop_params,
+         MainloopPipeline pipeline_k,
+         MainloopPipeline pipeline_v,
+         MainloopPipelineNoTMA pipeline_vt,         
+         PipelineState& smem_pipe_write,
+         PipelineState& smem_pipe_read,
+         SharedStorage &shared_storage,
+         Scheduler& scheduler,
+         typename Scheduler::Params const& scheduler_params,
+         typename Scheduler::WorkTileInfo& work_tile_info,
+         cute::tuple<int32_t, int32_t, int32_t> block_coord,
+         int work_idx,
+         const Seqlen_traits& seqlen_traits_q,
+         const Seqlen_traits& seqlen_traits_k         
+         ) {
+        
+        using SmemLayoutTransposeV = typename Ktraits::SmemLayoutTransposeV;
+        using SmemLayoutTransposeVt = typename Ktraits::SmemLayoutTransposeVt;
+
+        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
+        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
+        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
+        
+        Tensor sV_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutTransposeV{}));
+        Tensor sVt_divide = as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutTransposeVt{}));
+
+        auto smem_transpose_V = SmemTransposeFp8_64x64();
+        auto do_transpose_V = [&](int stage) {
+            CUTLASS_PRAGMA_UNROLL
+            for (int j = 0; j < shape<2>(SmemLayoutTransposeV{}); ++j) {
+                CUTLASS_PRAGMA_UNROLL
+                for (int i = 0; i < shape<1>(SmemLayoutTransposeV{}); ++i) {
+                smem_transpose_V(flatten(sV_divide(_, i, j, stage)),
+                                flatten(sVt_divide(_, i, j, stage)));
+                }
+            }
+        };
+
+        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
+        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
+        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
+
+        auto [m_block, bidh, bidb] = block_coord;
+        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
+
+        // Prepare the TMA loads
+        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
+        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
+        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
+        Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
+            mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
+        Tensor gK = seqlen_traits_k.get_local_tile_tensor(
+            mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
+        Tensor gV = seqlen_traits_k.get_local_tile_tensor(
+            mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
+
+        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
+        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
+        auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
+                                          group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));  // (TMA), (TMA)
+        auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
+                                          group_modes<0, 2>(sK), group_modes<0, 2>(gK));  // (TMA, k), (TMA, PIPE)
+        auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
+                                          group_modes<0, 2>(sV), group_modes<0, 2>(gV));  // (TMA, k), (TMA, PIPE)
+
+        uint16_t mcast_mask_kv = 0;
+        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
+            auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
+            for (int m = 0; m < size<0>(block_layout); ++m) {
+                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
+            }
+        }
+
+        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
+        int n_block = n_block_max - 1;
+
+        int lane_predicate = cute::elect_one_sync();
+        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
+        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
+            pipeline_k.producer_acquire(smem_pipe_write);
+            copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
+        }
+
+        // Wait for the MMA warpgroups to say that smem_q is ready
+        // for fp8, change from NumThreadsPerWarp to NumThreadsPerWarpGroup
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
+
+        if constexpr(Is_causal) {
+            if (warp_idx_in_warpgroup == 0 && lane_predicate) {
+                shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
+                copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
+                pipeline_v.producer_acquire(smem_pipe_write);
+                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                    tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));
+            }
+
+            shared_storage.barrier_O.wait((work_idx + 1) % 2);            
+                        
+            CUTLASS_PRAGMA_UNROLL
+            for (int iter = 0; iter < kStages && n_block > 0; ++iter, --n_block) {
+                pipeline_v.consumer_wait(smem_pipe_read);
+                // pipeline_vt.producer_acquire(smem_pipe_write);
+                do_transpose_V(smem_pipe_read.index());
+                pipeline_vt.producer_commit(smem_pipe_write);
+                pipeline_v.consumer_release(smem_pipe_read);
+
+                ++smem_pipe_write;
+                ++smem_pipe_read;
+                
+                if (warp_idx_in_warpgroup == 0 && lane_predicate) {
+                    pipeline_k.producer_acquire(smem_pipe_write);
+                    copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                        tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index()));
+                    pipeline_v.producer_acquire(smem_pipe_write);
+                    copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                        tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index()));
+                }
+            }            
+            
+            #pragma unroll 2
+            for (; n_block > 0; --n_block) {
+                pipeline_v.consumer_wait(smem_pipe_read);
+                pipeline_vt.producer_acquire(smem_pipe_write);
+                do_transpose_V(smem_pipe_read.index());
+                pipeline_vt.producer_commit(smem_pipe_write);
+                pipeline_v.consumer_release(smem_pipe_read);
+
+                ++smem_pipe_write;
+                ++smem_pipe_read;
+                
+                if (warp_idx_in_warpgroup == 0 && lane_predicate) {
+                    pipeline_k.producer_acquire(smem_pipe_write);
+                    copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                        tKgK(_, n_block-1), tKsK(_, smem_pipe_write.index()));
+                    pipeline_v.producer_acquire(smem_pipe_write);
+                    copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                        tVgV(_, n_block-1), tVsV(_, smem_pipe_write.index()));
+                }                                                                
+            }       
+
+            scheduler.prefetch_next_work(scheduler_params, work_tile_info);
+            scheduler.broadcast_next_work(work_tile_info);
+            
+            pipeline_v.consumer_wait(smem_pipe_read);
+            if (n_block_max > kStages)
+                pipeline_vt.producer_acquire(smem_pipe_write);
+            do_transpose_V(smem_pipe_read.index());
+            pipeline_vt.producer_commit(smem_pipe_write);
+            pipeline_v.consumer_release(smem_pipe_read);
+
+            ++smem_pipe_write;
+            ++smem_pipe_read;
+        } else {
+            if (warp_idx_in_warpgroup == 0 && lane_predicate) {
+                shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
+                copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
+                pipeline_v.producer_acquire(smem_pipe_write);
+                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                    tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));        
+            }
+            // With fp8 kernel, smem_o is in union with smem_v_out,
+            // so could use NamedBarrier instead of ClusterBarrier.
+            // But, this doesn't appear to have any benefit.
+            shared_storage.barrier_O.wait((work_idx + 1) % 2);
+
+            pipeline_v.consumer_wait(smem_pipe_read);
+            // pipeline_vt.producer_acquire(smem_pipe_write);
+            do_transpose_V(smem_pipe_read.index());
+            pipeline_vt.producer_commit(smem_pipe_write);
+            pipeline_v.consumer_release(smem_pipe_read);
+
+            ++smem_pipe_write;
+            ++smem_pipe_read;
+            --n_block;
+
+            constexpr int extra_iterations = kStages - 1;
+            CUTLASS_PRAGMA_UNROLL
+            for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter) {
+                if (warp_idx_in_warpgroup == 0 && lane_predicate) {
+                    pipeline_k.producer_acquire(smem_pipe_write);
+                    copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                        tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
+                    pipeline_v.producer_acquire(smem_pipe_write);
+                    copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                        tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));                
+                }
+                
+                pipeline_v.consumer_wait(smem_pipe_read);
+                // pipeline_vt.producer_acquire(smem_pipe_write);
+                do_transpose_V(smem_pipe_read.index());
+                pipeline_vt.producer_commit(smem_pipe_write);
+                pipeline_v.consumer_release(smem_pipe_read);
+                
+                ++smem_pipe_write;
+                ++smem_pipe_read;
+                --n_block;
+            }
+
+            // CUTLASS_PRAGMA_NO_UNROLL
+            #pragma unroll 2        
+            for (; n_block >= 0; --n_block) {
+                
+                if (warp_idx_in_warpgroup == 0 && lane_predicate) {
+                    pipeline_k.producer_acquire(smem_pipe_write);
+                    copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                        tKgK(_, n_block), tKsK(_, smem_pipe_write.index()));
+                    pipeline_v.producer_acquire(smem_pipe_write);
+                    copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write), mcast_mask_kv),
+                        tVgV(_, n_block), tVsV(_, smem_pipe_write.index()));                                
+                }
+                
+                pipeline_v.consumer_wait(smem_pipe_read);
+                pipeline_vt.producer_acquire(smem_pipe_write);
+                do_transpose_V(smem_pipe_read.index());
+                pipeline_vt.producer_commit(smem_pipe_write);
+                pipeline_v.consumer_release(smem_pipe_read);
+                
+                ++smem_pipe_write;
+                ++smem_pipe_read;
+            }
+            // scheduler.prefetch_next_work(scheduler_params, work_tile_info);
+            // scheduler.broadcast_next_work(work_tile_info);
+        }
+    }
+
     /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
     CUTLASS_DEVICE void
     load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
               PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
         int lane_predicate = cute::elect_one_sync();
+        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
         // Issue the epilogue waits
-        if (lane_predicate) {
+        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
           /* This helps avoid early exit of blocks in Cluster
           * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
           * then would just be acquired since the phase was still inverted from make_producer_start_state
@@ -295,6 +574,23 @@ struct CollectiveMainloopFwd {
         }
     }
 
+    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
+    CUTLASS_DEVICE void
+    load_tail_one_write(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
+              PipelineState& smem_pipe_write) {
+        int lane_predicate = cute::elect_one_sync();
+        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
+        // Issue the epilogue waits
+        if (warp_idx_in_warpgroup == 0 && lane_predicate) {
+          /* This helps avoid early exit of blocks in Cluster
+          * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
+          * then would just be acquired since the phase was still inverted from make_producer_start_state
+          */
+          pipeline_k.producer_tail(smem_pipe_write);
+          pipeline_v.producer_tail(smem_pipe_write);
+        }
+    }
+
     CUTLASS_DEVICE void
     warp_scheduler_barrier_sync() {
         if constexpr (UseSchedulerBarrier) {
@@ -317,7 +613,7 @@ struct CollectiveMainloopFwd {
     CUTLASS_DEVICE void
     mma_init() {
         // Tell producer (warp 0) that smem_q is ready
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + Ktraits::NumProducerThreads, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);                
         if constexpr (!UseSchedulerBarrier) { return; }
         static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
         if (cutlass::canonical_warp_group_idx() > 1) {
@@ -387,6 +683,7 @@ struct CollectiveMainloopFwd {
         warp_scheduler_barrier_sync();
         flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
         warp_scheduler_barrier_arrive();
+    
         if (work_idx != 0) {
             int lane_predicate = cute::elect_one_sync();
             if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
@@ -495,6 +792,234 @@ struct CollectiveMainloopFwd {
         return;
     }
 
+    template <bool Delay_V_release = false, typename SharedStorage, typename FrgTensorO, typename Softmax>
+    CUTLASS_DEVICE void
+    mma_fp8(Params const& mainloop_params,
+        MainloopPipeline pipeline_k,
+        MainloopPipelineNoTMA pipeline_vt,
+        PipelineState& smem_pipe_read,
+        PipelineState& smem_pipe_release,        
+        FrgTensorO& tOrO,
+        Softmax& softmax,
+        int n_block_count,
+        int thread_idx,
+        int work_idx,
+        int m_block,
+        SharedStorage& shared_storage,
+        const Seqlen_traits& seqlen_traits_q,
+        const Seqlen_traits& seqlen_traits_k
+        ) {
+        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
+
+        static constexpr int kBlockM = get<0>(TileShape_MNK{});
+        static constexpr int kBlockN = get<1>(TileShape_MNK{});
+
+        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
+        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
+        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v_out.data()), SmemLayoutVt{});
+
+        typename Ktraits::TiledMma0 tiled_mma0;
+        typename Ktraits::TiledMma1 tiled_mma1;
+        auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
+        auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
+
+        // Allocate "fragments/descriptors" for first matmul.
+        Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
+        Tensor tSrK = threadMma0.partition_fragment_B(sK);
+        // Allocate "fragments/descriptors" for second matmul.
+        Tensor tOrV = threadMma1.partition_fragment_B(sVt);
+
+        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
+            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
+            pipeline.consumer_wait(smem_pipe_read, barrier_token);
+        };
+
+        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
+        // workaround for fp8 only perf regression pending change to seqlen traits class
+        int const seqlen_q = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_q.actual_seq_len : shape<0>(mainloop_params.layout_Q);
+        int const seqlen_k = Seqlen_traits::kUseVarSeqLen ? seqlen_traits_k.actual_seq_len : shape<0>(mainloop_params.layout_K);
+        int n_block = n_block_count - 1;
+        
+        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
+        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }
+        
+        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));        
+        
+        consumer_wait(pipeline_k, smem_pipe_read);                        
+        warp_scheduler_barrier_sync();
+        flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
+        if (work_idx != 0) {        
+            int lane_predicate = cute::elect_one_sync();
+            if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
+                tma_store_wait<0>();
+                #pragma unroll
+                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
+                    shared_storage.barrier_O.arrive(cta_id, lane_predicate);
+                }
+            }        
+        }
+        warpgroup_wait<0>();
+        warp_scheduler_barrier_arrive();
+        pipeline_k.consumer_release(smem_pipe_read);
+
+        auto col_limit_causal = [&](int row, int n_block) {
+            return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
+        };       
+        {
+            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
+            Tensor tScS = threadMma0.partition_C(cS);
+            #pragma unroll
+            for (int i = 0; i < size(tSrS); ++i) {
+                if constexpr (!Is_causal) {  // Just masking based on col                
+                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
+                } else {  // mask based on both row and col
+                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
+                                                         col_limit_causal(int(get<0>(tScS(i))), n_block))) {
+                        tSrS(i) = -INFINITY;
+                    }
+                }
+            }
+        }
+
+        softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
+        Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
+        permute_regs_A_to_C(tOrP);
+        
+        Tensor scores_scale = make_fragment_like(softmax.row_max);
+        clear(scores_scale);
+        
+        consumer_wait(pipeline_vt, smem_pipe_read);
+        flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);                
+        if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
+
+        ++smem_pipe_read;
+        --n_block;
+        constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM, kBlockN);        
+
+        if constexpr(Is_causal) {
+            CUTLASS_PRAGMA_UNROLL      
+            for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
+                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
+                consumer_wait(pipeline_k, smem_pipe_read);
+                warp_scheduler_barrier_sync();
+                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
+
+                Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
+                Tensor tScS = threadMma0.partition_C(cS);
+                #pragma unroll
+                for (int i = 0; i < size(tSrS); ++i) {
+                    if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block)) {
+                        tSrS(i) = -INFINITY;
+                    }
+                }
+
+                warp_scheduler_barrier_arrive();
+                pipeline_k.consumer_release(smem_pipe_read);
+                consumer_wait(pipeline_vt, smem_pipe_read);
+                
+                cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                softmax.rescale_o(tOrO, scores_scale);
+                softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
+                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
+                permute_regs_A_to_C(tOrP);
+
+                if constexpr(Delay_V_release) {
+                    pipeline_vt.consumer_release(smem_pipe_release);
+                    ++smem_pipe_release;
+                }
+                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);            
+                if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }                
+                ++smem_pipe_read;
+            }
+        } else {
+            CUTLASS_PRAGMA_UNROLL      
+            for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
+                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
+                consumer_wait(pipeline_k, smem_pipe_read);
+                if constexpr(Delay_V_release) {
+                    pipeline_vt.consumer_release(smem_pipe_release);
+                    ++smem_pipe_release;
+                }
+                warp_scheduler_barrier_sync();
+                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
+                warp_scheduler_barrier_arrive();
+                if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
+                else { consumer_wait(pipeline_vt, smem_pipe_read); }
+                
+                cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                softmax.rescale_o(tOrO, scores_scale);
+                softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
+                permute_regs_A_to_C(tOrP);
+
+                if constexpr (Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
+                else { consumer_wait(pipeline_vt, smem_pipe_read); }
+                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
+                if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }                
+                ++smem_pipe_read;
+            }
+        }
+
+        if constexpr(Delay_V_release) {
+            warp_scheduler_barrier_sync();
+            CUTLASS_PRAGMA_NO_UNROLL
+            for (; n_block >= 0; --n_block) {
+                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
+                consumer_wait(pipeline_k, smem_pipe_read);
+                pipeline_vt.consumer_release(smem_pipe_release);                
+                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
+                warp_scheduler_barrier_arrive();
+                warpgroup_wait<0>();                
+                consumer_wait(pipeline_vt, smem_pipe_read);
+
+                cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                softmax.rescale_o(tOrO, scores_scale);
+                softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
+                permute_regs_A_to_C(tOrP);
+
+                pipeline_k.consumer_release(smem_pipe_read);
+                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
+                warp_scheduler_barrier_sync();
+                warpgroup_wait<0>();
+                ++smem_pipe_read;
+                ++smem_pipe_release;
+            }
+            warp_scheduler_barrier_arrive();
+            pipeline_vt.consumer_release(smem_pipe_release);
+            ++smem_pipe_release;
+        } else {
+            if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
+            CUTLASS_PRAGMA_NO_UNROLL
+            for (; n_block >= 0; --n_block) {
+                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
+                consumer_wait(pipeline_k, smem_pipe_read);
+                if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); }
+                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
+                warp_scheduler_barrier_arrive();
+                pipeline_k.consumer_release(smem_pipe_read);
+
+                cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                softmax.rescale_o(tOrO, scores_scale);
+                softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
+                permute_regs_A_to_C(tOrP);
+
+                consumer_wait(pipeline_vt, smem_pipe_read);
+                if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
+                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
+                pipeline_vt.consumer_release(smem_pipe_read);
+                ++smem_pipe_read;
+            }
+            if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
+        }
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
+        
+        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+        softmax.rescale_o(tOrO, scores_scale);
+        return;
+    }
+
 };
 
 } // namespace flash

+ 1 - 0
hopper/named_barrier.hpp

@@ -18,6 +18,7 @@ enum class FwdNamedBarriers {
     WarpSchedulerWG1 = 4,
     WarpSchedulerWG2 = 5,
     WarpSchedulerWG3 = 6,
+    ProducerWG = 7
 };
 
 } // flash

+ 6 - 8
hopper/setup.py

@@ -119,7 +119,9 @@ if not SKIP_CUDA_BUILD:
         "flash_bwd_hdim64_fp16_sm90.cu",
         "flash_bwd_hdim128_fp16_sm90.cu",
         "flash_bwd_hdim256_fp16_sm90.cu",
-        # "flash_fwd_hdim128_e4m3_sm90.cu",
+        "flash_fwd_hdim64_e4m3_sm90.cu",
+        "flash_fwd_hdim128_e4m3_sm90.cu",
+        "flash_fwd_hdim256_e4m3_sm90.cu"
     ]
     nvcc_flags = [
         "-O3",
@@ -134,15 +136,11 @@ if not SKIP_CUDA_BUILD:
         "--expt-relaxed-constexpr",
         "--expt-extended-lambda",
         "--use_fast_math",
-        # "--ptxas-options=-v",  # printing out number of registers
+        "--ptxas-options=-v",  # printing out number of registers
         "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",  # printing out number of registers
         "-lineinfo",
         "-DCUTLASS_DEBUG_TRACE_LEVEL=0",  # Can toggle for debugging
-        "-DNDEBUG",  # Important, otherwise performance is severely impacted
-        "-DQBLKSIZE=128",
-        "-DKBLKSIZE=128",
-        "-DCTA256",
-        "-DDQINRMEM",
+        "-DNDEBUG",  # Important, otherwise performance is severely impacted             
     ]
     include_dirs = [
         # Path(this_dir) / "fmha-pipeline",
@@ -161,7 +159,7 @@ if not SKIP_CUDA_BUILD:
                 "cxx": ["-O3", "-std=c++17"],
                 # "cxx": ["-O0", "-std=c++17"],
                 "nvcc": append_nvcc_threads(
-                    nvcc_flags + ["-DEXECMODE=0"] + cc_flag
+                    nvcc_flags + cc_flag
                 ),
             },
             include_dirs=include_dirs,

+ 17 - 9
hopper/softmax.h

@@ -12,6 +12,8 @@
 
 #include "utils.h"
 
+#include "cutlass/fast_math.h"
+
 namespace flash {
 
 using namespace cute;
@@ -100,8 +102,10 @@ __forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &ten
 }
 
 // Apply the exp to all the elements.
-template <bool Scale_max=true, bool Check_inf=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+template <bool Scale_max=true, bool Check_inf=true, bool Use_max_offset=false,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
+    constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f;
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -111,8 +115,8 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
         // We don't want (-inf - (-inf)) since that would give NaN.
         // If we don't have float around M_LOG2E the multiplication is done in fp64.
         const float max_scaled = Check_inf
-            ? (max(mi) == -INFINITY ? 0.f : (max(mi) * (Scale_max ? scale : float(M_LOG2E))))
-            : (max(mi) * (Scale_max ? scale : float(M_LOG2E)));
+            ? (max(mi) == -INFINITY ? 0.f : (max(mi) * (Scale_max ? scale : float(M_LOG2E))) - max_offset)
+            : (max(mi) * (Scale_max ? scale : float(M_LOG2E)) - max_offset);
         #pragma unroll
         for (int ni = 0; ni < size<1>(tensor); ++ni)  {
             // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
@@ -125,8 +129,11 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <int kNRows>
-struct Softmax {
+template <int kNRows, bool Use_max_offset_ = false>
+struct Softmax { 
+    constexpr static bool Use_max_offset = Use_max_offset_; 
+    // constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f;
+    // constexpr static float max_offset_E = max_offset * float(M_LN2);
 
     using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
     TensorT row_max, row_sum;
@@ -166,7 +173,7 @@ struct Softmax {
         TensorT scores_scale;
         if constexpr (Is_first) {
             flash::template reduce_max</*zero_init=*/true>(scores, row_max);
-            flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            flash::template scale_apply_exp2</*Scale_max=*/true, /*Check_inf=*/true, Use_max_offset>(scores, row_max, softmax_scale_log2);
             flash::reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
             cute::fill(scores_scale, 1.f);
             // if (cute::thread0()) { print_tensor(scores); printf("\n scale = %f\n", softmax_scale_log2); print_tensor(row_sum); }
@@ -183,16 +190,17 @@ struct Softmax {
             //     scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
             //     row_sum(mi) *= scores_scale(mi);
             // }
-            flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf>(scores, row_max, softmax_scale_log2);
+            flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf, Use_max_offset>(scores, row_max, softmax_scale_log2);
             // We don't do the reduce across threads here since we don't need to use the row_sum.
             // We do that reduce at the end when we need to normalize the softmax.
             flash::reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
         }
         return scores_scale;
     };
-
+    
     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
     __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float softmax_scale_log2, float rp_dropout=1.0) {
+        constexpr static float max_offset_E = Use_max_offset ? 8.0f * float(M_LN2) : 0.0f;
         // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         static_assert(decltype(size<0>(scores))::value == kNRows);
@@ -203,7 +211,7 @@ struct Softmax {
         for (int mi = 0; mi < size(row_max); ++mi) {
             float sum = row_sum(mi);
             float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum;
-            row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
+            row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : (row_max(mi) * softmax_scale_log2) * float(M_LN2) - max_offset_E + __logf(sum);
             scores_scale(mi) = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
         }
         return scores_scale;

+ 31 - 10
hopper/test_flash_attn.py

@@ -24,9 +24,9 @@ def print_diffs(out, out_ref):
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-# @pytest.mark.parametrize("mha_type", ["gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [True])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
@@ -38,6 +38,7 @@ def print_diffs(out, out_ref):
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
+        (1, 1),
         (257, 1),
         (64, 128),
         (128, 128),
@@ -53,28 +54,43 @@ def print_diffs(out, out_ref):
         (1024, 1024),
         (1023, 1024),
         (1024, 1023),
-        (2048, 2048),
+        (4096, 4096),
     ],
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, causal, mha_type, dtype
+    seqlen_q, seqlen_k, d, causal, mha_type, dtype,    
 ):
     device = "cuda"
+    if(dtype == torch.float8_e4m3fn):
+        dtype_init = torch.float16
+    else:
+        dtype_init = dtype    
+    print(dtype)
     # set seed
     torch.random.manual_seed(0)
     # batch_size = 40
     # nheads = 16
-    batch_size = 9
+    batch_size = 4
     nheads = 6
     nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
     # nheads_kv = 2
     # batch_size = 9
     # nheads = 6
-    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
-    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_init, requires_grad=True)
+    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_init, requires_grad=True)
+
+    q = q.to(dtype)
+    k = k.to(dtype)
+    v = v.to(dtype)
+
     out, lse = flash_attn_func(q, k, v, causal=causal)
+
+    q = q.to(dtype_init)
+    k = k.to(dtype_init)
+    v = v.to(dtype_init)
+    
     out_ref, attn_ref = attention_ref(
         q,
         k,
@@ -105,8 +121,9 @@ def test_flash_attn_output(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+    
     # if not causal:
-    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
+    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")                
     # breakpoint()
 
     # if d <= 128:
@@ -139,7 +156,11 @@ def test_flash_attn_output(
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     # breakpoint()
-    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+    if(dtype != torch.float8_e4m3fn):
+        assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+    else:       
+        # just test correctness of fp8 kernel w/o further quantization techniques
+        assert (out - out_ref).abs().max().item() <= 40 * (out_pt - out_ref).abs().max().item()
 
     # if d <= 128:
     #     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()

+ 14 - 10
hopper/tile_scheduler.hpp

@@ -164,7 +164,7 @@ public:
 
 };
 
-template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup>
+template<int NumMmaThreads=2 * cutlass::NumThreadsPerWarpGroup, int NumProducerThreads = cutlass::NumThreadsPerWarp>
 class DynamicPersistentTileScheduler {
 
 protected:
@@ -228,13 +228,13 @@ public:
     CUTLASS_DEVICE
     void
     init_consumer() const {
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
     }
 
     CUTLASS_DEVICE
     void
     prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
-        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
+        if (threadIdx.x % NumProducerThreads == 0) {
             current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
         }
     }
@@ -242,24 +242,28 @@ public:
     CUTLASS_DEVICE
     void
     broadcast_next_work(WorkTileInfo& current_work) const {
-        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
-        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
+        if (threadIdx.x % NumProducerThreads == 0) {
             *tile_count_smem = current_work.tile_idx;
         }
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
     }
 
     template<bool IsProducer=false>
     CUTLASS_DEVICE
     WorkTileInfo
     get_next_work(Params const& params, WorkTileInfo const& current_work) const {
-        if constexpr (IsProducer) {
-            // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
+        if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarp) {
+            // thread 0 already has the right tile_idx, just need to broadcast to the rest of the producer threads (warp 0)
             return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
+        } else if constexpr (IsProducer && NumProducerThreads == cutlass::NumThreadsPerWarpGroup) {
+            // TODO: investigate optimal synchronize
+            int tile_idx = *tile_count_smem;
+            return {tile_idx};
         } else {
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
+            cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
             int tile_idx = *tile_count_smem;
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
             return {tile_idx};
         }
     }

+ 32 - 0
hopper/utils.h

@@ -143,6 +143,38 @@ __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
     }
 };
 
+// Convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
+template<typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_Aregs_fp8(Layout acc_layout) {
+    using X = Underscore;    
+    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
+    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
+    auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _4>{});  // (2, 2, (2, N / 32)))    
+    return make_layout(make_layout(Shape<_4, _2, _2>{}),
+                       get<1>(acc_layout),
+                       make_layout(get<2, 1>(l), get<2>(acc_layout)));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Byte permute for fp8 kernel
+template <typename Fragment>
+CUTLASS_DEVICE void permute_regs_A_to_C(Fragment &accum) {  
+
+  auto data = accum.data();    
+
+  #pragma unroll  
+  for (int n = 0; n < size(accum); n += 8) {
+      uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
+      auto upper = data_32bit[0];
+      auto lower = data_32bit[1];
+      data_32bit[0] = __byte_perm(upper, lower, 0x5410);
+      data_32bit[1] = __byte_perm(upper, lower, 0x7632);        
+  }
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <typename To_type, typename Engine, typename Layout>