Browse Source

Add var-seq-len to FA3 fp16 / bf16 fwd (#1072)

* fwd var-seq-len

* fixes

* benchmark

* fixes

---------

Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
Ying Zhang 8 months ago
parent
commit
dfe1a59e4b

+ 273 - 0
hopper/benchmark_attn.py

@@ -0,0 +1,273 @@
+from functools import partial
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import time
+
+try:
+    import cudnn
+except ImportError:
+    cudnn = None
+
+
+from einops import rearrange, repeat
+
+# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
+from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
+from flash_attn.flash_attn_interface import flash_attn_func
+from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3
+
+# Need to install triton nightly:
+# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
+
+try:
+    from triton_fused_attention import attention as triton_attention
+except ImportError:
+    triton_attention = None
+
+def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+
+def convert_to_cudnn_type(torch_type):
+    if torch_type == torch.float16:
+        return cudnn.data_type.HALF
+    elif torch_type == torch.bfloat16:
+        return cudnn.data_type.BFLOAT16
+    elif torch_type == torch.float32:
+        return cudnn.data_type.FLOAT
+    elif torch_type == torch.int32:
+        return cudnn.data_type.INT32
+    elif torch_type == torch.int64:
+        return cudnn.data_type.INT64
+    else:
+        raise ValueError("Unsupported tensor data type.")
+
+
+def cudnn_sdpa_setup(q, k, v, grad, causal=False):
+    b, nheads, seqlen_q, headdim = q.shape
+    _, _, seqlen_k, _ = k.shape
+    assert v.shape == (b, nheads, seqlen_k, headdim)
+    assert cudnn is not None, 'CUDNN is not available'
+    q_gpu, k_gpu, v_gpu = q, k, v
+    o_gpu = torch.empty_like(q_gpu)
+    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
+    graph_forward = cudnn.pygraph(
+        io_data_type=convert_to_cudnn_type(q.dtype),
+        intermediate_data_type=cudnn.data_type.FLOAT,
+        compute_data_type=cudnn.data_type.FLOAT,
+    )
+    q_forward = graph_forward.tensor_like(q_gpu.detach())
+    k_forward = graph_forward.tensor_like(k_gpu.detach())
+    v_forward = graph_forward.tensor_like(v_gpu.detach())
+
+    o_forward, stats_forward = graph_forward.sdpa(
+        name="sdpa",
+        q=q_forward,
+        k=k_forward,
+        v=v_forward,
+        is_inference=False,
+        attn_scale=1.0 / math.sqrt(headdim),
+        use_causal_mask=causal,
+    )
+
+    o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
+    stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT)
+
+    graph_forward.validate()
+    graph_forward.build_operation_graph()
+    graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
+    graph_forward.check_support()
+    graph_forward.build_plans()
+
+    variant_pack_forward = {
+        q_forward: q_gpu,
+        k_forward: k_gpu,
+        v_forward: v_gpu,
+        o_forward: o_gpu,
+        stats_forward: stats_gpu,
+    }
+
+    dQ_gpu = torch.empty_like(q_gpu)
+    dK_gpu = torch.empty_like(k_gpu)
+    dV_gpu = torch.empty_like(v_gpu)
+    dO_gpu = grad
+
+    graph_backward = cudnn.pygraph(
+        io_data_type=cudnn.data_type.HALF,
+        intermediate_data_type=cudnn.data_type.FLOAT,
+        compute_data_type=cudnn.data_type.FLOAT,
+    )
+    
+    q_backward = graph_backward.tensor_like(q_gpu.detach())
+    k_backward = graph_backward.tensor_like(k_gpu.detach())
+    v_backward = graph_backward.tensor_like(v_gpu.detach())
+    o_backward = graph_backward.tensor_like(o_gpu.detach())
+    dO_backward = graph_backward.tensor_like(dO_gpu.detach())
+    stats_backward = graph_backward.tensor_like(stats_gpu.detach())
+    
+    dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward(
+        name="sdpa_backward",
+        q=q_backward,
+        k=k_backward,
+        v=v_backward,
+        o=o_backward,
+        dO=dO_backward,
+        stats=stats_backward,
+        attn_scale=1.0 / math.sqrt(headdim),
+        use_causal_mask=causal,
+    )
+    
+    dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride())
+    dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride())
+    dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride())
+    
+    graph_backward.validate()
+    graph_backward.build_operation_graph()
+    graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
+    graph_backward.check_support()
+    graph_backward.build_plans()
+
+    variant_pack_backward = {
+        q_backward: q_gpu,
+        k_backward: k_gpu,
+        v_backward: v_gpu,
+        o_backward: o_gpu,
+        dO_backward: dO_gpu,
+        stats_backward: stats_gpu,
+        dQ_backward: dQ_gpu,
+        dK_backward: dK_gpu,
+        dV_backward: dV_gpu,
+    }
+
+    workspace = torch.empty(
+        max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()), 
+        device="cuda", dtype=torch.uint8
+    )
+
+    def run_fwd(*args, **kwargs):
+        graph_forward.execute(variant_pack_forward, workspace)
+        return o_gpu, stats_gpu
+
+    def run_bwd(*args, **kwargs):
+        graph_backward.execute(variant_pack_backward, workspace)
+        return dQ_gpu, dK_gpu, dV_gpu
+
+    return run_fwd, run_bwd
+
+
+torch.manual_seed(0)
+repeats = 100
+dropout_p = 0.0
+causal = False
+dtype = torch.float16
+device = 'cuda'
+verbose = False
+batch_size = 2
+# seqlen = 2048
+seqlen = 8192
+# seqlen = 4096
+# seqlen = 2047
+dim = 2048
+# headdim = 128
+# headdim = 64
+headdim = 256
+
+# for mode in ['fwd', 'bwd']:
+for mode in ['fwd']:
+    for headdim in [64, 128, 256]:
+    # for headdim in [128]:
+        for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
+        # for seqlen in [8192]:
+            nheads = dim // headdim
+            # nheads = 24
+            # headdim = 64
+            # batch_size = 64
+            # seqlen = 512
+            # nheads = 8
+            # headdim = 128
+            nheads_kv = nheads
+    
+            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                            requires_grad=True)
+            q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+            k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+            v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+            q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
+            k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
+            v_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
+            grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
+            grad_t = grad.transpose(1, 2).contiguous()
+    
+            bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
+
+            for causal in [False, True]:
+            # for causal in [True]:
+                print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
+                if headdim <= 128 and cudnn is not None:
+                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
+                f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
+                _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
+                if mode == 'bwd':
+                    ref_dv, v.grad = v.grad.clone(), None
+                    ref_dk, k.grad = k.grad.clone(), None
+                    ref_dq, q.grad = q.grad.clone(), None
+                # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
+                if headdim <= 128:
+                    if triton_attention is not None:
+                        if mode == 'fwd':
+                            time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
+                            _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
+                        # TODO: fix Triton numeric errors.
+                        # if mode == 'bwd':
+                        #     dv, v_t.grad = v_t.grad.clone(), None
+                        #     dk, k_t.grad = k_t.grad.clone(), None
+                        #     dq, q_t.grad = q_t.grad.clone(), None
+                        #     torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
+                        #     torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
+                        #     torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
+                    if cudnn is not None:
+                        time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
+                        if mode == 'fwd':
+                            _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
+                        else:
+                            cudnn_sdpa_fwd()
+                            _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
+                            dq, dk, dv = cudnn_sdpa_bwd()
+                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
+                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
+                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
+                        # pytorch_profiler(cudnn_sdpa, backward=False)
+                if headdim == 128 or mode == 'fwd':
+                    time.sleep(1)
+                    _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
+                    q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
+                    k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
+                    v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
+                    lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
+                    cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
+                    time.sleep(1)
+                    _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
+                    if mode == 'bwd':
+                        dv, v.grad = v.grad.clone(), None
+                        dk, k.grad = k.grad.clone(), None
+                        dq, q.grad = q.grad.clone(), None
+                        torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05)
+                        torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
+                        torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
+ 
+                # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
+                print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
+                if headdim <= 128:
+                    if triton_attention is not None:
+                        print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
+                    if cudnn is not None:
+                        print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
+                if headdim == 128 or mode == 'fwd':
+                    print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
+                    print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
+    

+ 72 - 68
hopper/epilogue_fwd_sm90_tma.hpp

@@ -17,20 +17,15 @@ namespace flash {
 using namespace cute;
 
 // template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
-template <typename Ktraits>
+template <typename Ktraits, typename Seqlen_traits>
 struct CollectiveEpilogueFwd {
 
     using Element = typename Ktraits::Element;
     static constexpr int kBlockM = Ktraits::kBlockM;
     static constexpr int kBlockN = Ktraits::kBlockN;
     static constexpr int kHeadDim = Ktraits::kHeadDim;
-    // using Element = Element_;
-    // static constexpr int kBlockM = kBlockM_;
-    // static constexpr int kBlockN = kBlockN_;
-    // static constexpr int kHeadDim = kHeadDim_;
     using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
 
-    // static constexpr int kNWarps = kNWarps_;
     static constexpr int kNWarps = Ktraits::kNWarps;
     static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
     static constexpr bool Is_WS = kNWarps >= 12;
@@ -38,20 +33,6 @@ struct CollectiveEpilogueFwd {
     static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
     static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;
 
-    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
-
-    // These are for storing the output tensor without TMA (e.g., for setting output to zero)
-    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
-    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
-    static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad;
-    static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow");
-    using GmemLayoutAtom = Layout<Shape <Int<NumMmaThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
-                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
-    using GmemTiledCopyO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
-                        GmemLayoutAtom{},
-                        Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 or 16 vals per store
-
     using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
         decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
     using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
@@ -59,52 +40,72 @@ struct CollectiveEpilogueFwd {
     using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
     using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;
 
-    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch)
-    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
-    using StrideLSE = cute::Stride<_1, int64_t, int64_t>;            // (seqlen_q, head, batch)
-
+    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
     using TMA_O = decltype(make_tma_copy(
         GmemTiledCopyOTMA{},
-        make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}),
+        make_tensor(
+            make_gmem_ptr(static_cast<Element*>(nullptr)), 
+            typename Seqlen_traits::ShapeT{}, 
+            typename Seqlen_traits::StrideT{}
+        ),
         SmemLayoutO{},
         select<0, 2>(TileShape_MNK{}),
         _1{}));  // no mcast for O
 
+    // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len)
+    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
+    static_assert(kHeadDim % kNumVecElem == 0);
+    static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
+    static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
+    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
+    using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
+    using TiledCopyOThrLayout = decltype(cute::make_layout(
+        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
+        LayoutRight{}));
+    using TiledCopyOValLayout = decltype(cute::make_layout(
+        cute::make_shape(_1{}, Int<kNumVecElem>{}),
+        LayoutRight{}));
+    using TiledCopyO = decltype(make_tiled_copy(
+        TiledCopyOAtom{},
+        TiledCopyOThrLayout{}, // Thr layout
+        TiledCopyOValLayout{} // Val layout
+    ));
+
     // Host side kernel arguments
     struct Arguments {
         Element* ptr_O;
-        ShapeO const shape_O;
-        StrideO const stride_O;
+        typename Seqlen_traits::LayoutT const layout_O;
         float* ptr_LSE;
-        StrideLSE const stride_LSE;
+        typename Seqlen_traits::LayoutLseT const layout_LSE;
     };
 
     // Device side kernel params
     struct Params {
         Element* ptr_O;
-        ShapeO const shape_O;
-        StrideO const stride_O;
+        typename Seqlen_traits::LayoutT const layout_O;
         float* ptr_LSE;
-        StrideLSE const stride_LSE;
+        typename Seqlen_traits::LayoutLseT const layout_LSE;
         TMA_O tma_store_O;
     };
 
     static Params
     to_underlying_arguments(Arguments const& args) {
-        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
+        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O);
         TMA_O tma_store_O = make_tma_copy(
             GmemTiledCopyOTMA{},
             mO,
             SmemLayoutO{},
             select<0, 2>(TileShape_MNK{}),
             _1{}); // no mcast for O
-        return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O};
+        return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
     CUTLASS_DEVICE
     static void prefetch_tma_descriptors(Params const& epilogue_params) {
-        cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
+        if constexpr (!Seqlen_traits::kUseVarSeqLen) {
+            cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
+        }
     }
 
     template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
@@ -115,7 +116,8 @@ struct CollectiveEpilogueFwd {
           SharedStorage& shared_storage,
           TiledMma tiled_mma,
           int thread_idx,
-          cute::tuple<int32_t, int32_t, int32_t> const& block_coord
+          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
+          const Seqlen_traits& seqlen_traits_q
           ) {
 
         auto [m_block, bidh, bidb] = block_coord;
@@ -134,16 +136,9 @@ struct CollectiveEpilogueFwd {
         cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                             cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
 
-        Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O);
-        Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
-        auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{});
-        Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
-        Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
-
-        auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
-        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
-        Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
-
+        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
+        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
+            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
         Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
         auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
         Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
@@ -156,19 +151,23 @@ struct CollectiveEpilogueFwd {
             #pragma unroll
             for (int mi = 0; mi < size(lse); ++mi) {
                 const int row = get<0>(taccOcO_row(mi));
-                if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); }
+                if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
             }
         }
 
-        if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) {
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
-                                              cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
-            int const lane_predicate = cute::elect_one_sync();
-            if (lane_predicate) {
-                cute::copy(epilogue_params.tma_store_O, tOsO, tOgO);
-                tma_store_arrive();
-            }
+        int write_warp_idx = kNWarps - 1;
+        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
+            cutlass::arch::NamedBarrier::sync(
+                NumMmaThreads + cutlass::NumThreadsPerWarp, 
+                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
+            );
         }
+        TiledCopyO gmem_tiled_copy_O;
+        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
+            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, 
+            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO, 
+            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
+        );
     }
 
     CUTLASS_DEVICE void
@@ -177,20 +176,25 @@ struct CollectiveEpilogueFwd {
     }
 
     // Write 0 to output and -inf to LSE
+    template<typename SharedStorage>
     CUTLASS_DEVICE void
     store_zero(
-         Params const& epilogue_params,
-         int thread_idx,
-         cute::tuple<int32_t, int32_t, int32_t> const& block_coord
-         ) {
+          Params const& epilogue_params,
+          SharedStorage& shared_storage,
+          int thread_idx,
+          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
+          const Seqlen_traits& seqlen_traits_q
+          ) {
         auto [m_block, bidh, bidb] = block_coord;
-        Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O);
-        Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
-        auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O);
-        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE);
-        Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape<Int<kBlockM>>{}, make_coord(m_block));
-
-        GmemTiledCopyO gmem_tiled_copy_O;
+        Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
+        Tensor gO = seqlen_traits_q.get_local_tile_tensor(
+            mO, select<0, 2>(TileShape_MNK{}), bidh, bidb
+        )(_, _, m_block);  // (M, K)
+        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
+        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
+            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
+
+        TiledCopyO gmem_tiled_copy_O;
         auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
         Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
         Tensor tOrO = make_fragment_like(tOgO);
@@ -201,13 +205,13 @@ struct CollectiveEpilogueFwd {
         Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
         Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
         #pragma unroll
-        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); }
+        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
         flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
         );
         static_assert(kBlockM <= NumMmaThreads);
-        if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
+        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
     }
 
 };

+ 3 - 1
hopper/flash.h

@@ -57,7 +57,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q, total_k;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@@ -128,6 +128,8 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
 
+    bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+
     int * __restrict__ tile_count_semaphore;
 };
 

+ 158 - 1
hopper/flash_api.cpp

@@ -43,7 +43,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       float softmax_scale,
                       int window_size_left,
                       int window_size_right,
-                      bool seqlenq_ngroups_swapped=false) {
+                      bool seqlenq_ngroups_swapped=false,
+                      bool unpadded_lse=false) {
 
     // Reset the parameters
     params = {};
@@ -81,6 +82,11 @@ void set_params_fprop(Flash_fwd_params &params,
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
     params.seqused_k = static_cast<int *>(seqused_k);
 
+    TORCH_CHECK(
+        bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
+        "cu_seqlens_q and cu_seqlens_k must be both null or non-null"
+    );
+
     // P = softmax(QK^T)
     params.p_ptr = p_d;
 
@@ -139,6 +145,8 @@ void set_params_fprop(Flash_fwd_params &params,
     #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
         TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
     #endif
+
+    params.unpadded_lse = unpadded_lse;
 }
 
 void set_params_dgrad(Flash_bwd_params &params,
@@ -372,6 +380,154 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
 }
 
+std::vector<at::Tensor>
+mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q,  // b+1
+               const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               int max_seqlen_q,
+               const int max_seqlen_k,
+               const float softmax_scale,
+               bool is_causal) {
+
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
+
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
+    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
+
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+    CHECK_DEVICE(cu_seqlens_q);
+    CHECK_DEVICE(cu_seqlens_k);
+
+    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    CHECK_CONTIGUOUS(cu_seqlens_q);
+    CHECK_CONTIGUOUS(cu_seqlens_k);
+
+    const auto sizes = q.sizes();
+
+    const int batch_size = cu_seqlens_q.numel() - 1;
+    int num_heads = sizes[1];
+    const int head_size_og = sizes[2];
+    const int num_heads_k = k.size(1);
+
+    int window_size_left = -1;
+    int window_size_right = -1;
+    if (is_causal) { window_size_right = 0; }
+
+    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
+
+    const int total_q = q.sizes()[0];
+
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
+    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
+    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
+    const int total_k = k.size(0);
+    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+
+    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }
+
+    at::Tensor q_padded, k_padded, v_padded;
+    if (head_size_og % 8 != 0) {
+        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    } else {
+        q_padded = q;
+        k_padded = k;
+        v_padded = v;
+    }
+
+    at::Tensor out;
+    if (out_.has_value()) {
+        out = out_.value();
+        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        CHECK_DEVICE(out);
+        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+        CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
+    } else {
+        out = torch::empty_like(q_padded);
+    }
+
+    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+    const int head_size = round_multiple(head_size_og, 8);
+    const int head_size_rounded = round_multiple(head_size, 32);
+    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
+    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+    auto opts = q.options();
+    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
+
+    Flash_fwd_params params;
+    set_params_fprop(params,
+                     batch_size,
+                     max_seqlen_q, max_seqlen_k,
+                     seqlen_q_rounded, seqlen_k_rounded,
+                     num_heads, num_heads_k,
+                     head_size, head_size_rounded,
+                     q_padded, k_padded, v_padded, out,
+                     cu_seqlens_q_d,
+                     cu_seqlens_k.data_ptr(),
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
+                     /*p_d=*/nullptr,
+                     softmax_lse.data_ptr(),
+                     /*p_dropout=*/0.f,
+                     softmax_scale,
+                     window_size_left,
+                     window_size_right,
+                     /*seqlenq_ngroups_swapped=*/false,
+                     /*unpadded_lse=*/true);
+    params.total_q = total_q;
+    params.total_k = total_k;
+
+    if (max_seqlen_k > 0) {
+        auto stream = at::cuda::getCurrentCUDAStream().stream();
+        run_mha_fwd(params, stream);
+    } else {
+        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+        out.zero_();
+        softmax_lse.fill_(std::numeric_limits<float>::infinity());
+    }
+
+    at::Tensor out_padded = out;
+    if (head_size_og % 8 != 0) {
+        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+        if (out_.has_value()) { out_.value().copy_(out); }
+    }
+
+    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
+}
+
 void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     // FP16_SWITCH(!params.is_bf16, [&] {
     //     HEADDIM_SWITCH(params.d, [&] {
@@ -577,4 +733,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "FlashAttention";
     m.def("fwd", &mha_fwd, "Forward pass");
     m.def("bwd", &mha_bwd, "Backward pass");
+    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
 }

+ 201 - 0
hopper/flash_attn_interface.py

@@ -57,6 +57,83 @@ def _flash_attn_backward(
     )
     return dq, dk, dv, softmax_d
 
+def _flash_attn_varlen_forward(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    softmax_scale,
+    causal,
+):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    out, q, k, v, out_padded, softmax_lse = flashattn_hopper_cuda.varlen_fwd(
+        q,
+        k,
+        v,
+        None,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        None,
+        max_seqlen_q,
+        max_seqlen_k,
+        softmax_scale,
+        causal,
+    )
+    # if out.isnan().any() or softmax_lse.isnan().any():
+    #     breakpoint()
+    return out, q, k, v, out_padded, softmax_lse
+
+
+def _flash_attn_varlen_backward(
+    dout,
+    q,
+    k,
+    v,
+    out,
+    softmax_lse,
+    dq,
+    dk,
+    dv,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    softmax_scale,
+    causal,
+):
+    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    # dq, dk, dv are allocated by us so they should already be contiguous
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    (
+        dq,
+        dk,
+        dv,
+        softmax_d,
+    ) = _get_fa_module().varlen_bwd(
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        softmax_scale,
+        causal,
+    )
+    # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
+    #     breakpoint()
+    return dq, dk, dv, softmax_d
+
 
 class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
@@ -105,6 +182,71 @@ class FlashAttnFunc(torch.autograd.Function):
         return dq, dk, dv, None, None
 
 
+class FlashAttnVarlenFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        softmax_scale,
+        causal,
+    ):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+        out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            softmax_scale,
+            causal=causal,
+        )
+        ctx.save_for_backward(
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
+        )
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out, softmax_lse
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        # TODO: Uncomment these when var-seq-len is supported in bwd kernel.
+        # q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        # dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
+        # _flash_attn_varlen_backward(
+        #     dout,
+        #     q,
+        #     k,
+        #     v,
+        #     out,
+        #     softmax_lse,
+        #     dq,
+        #     dk,
+        #     dv,
+        #     cu_seqlens_q,
+        #     cu_seqlens_k,
+        #     ctx.max_seqlen_q,
+        #     ctx.max_seqlen_k,
+        #     ctx.softmax_scale,
+        #     ctx.causal,
+        # )
+        # dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
+        # dk = dk[..., : dout.shape[-1]]
+        # dv = dv[..., : dout.shape[-1]]
+        # return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
+        return None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+
+
 def flash_attn_func(
     q,
     k,
@@ -167,3 +309,62 @@ def flash_attn_func(
         softmax_scale,
         causal,
     )
+
+
+def flash_attn_varlen_func(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    softmax_scale=None,
+    causal=False,
+):
+    """
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+    Arguments:
+        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+           of the sequences in the batch, used to index into q.
+        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+           of the sequences in the batch, used to index into kv.
+        max_seqlen_q: int. Maximum query sequence length in the batch.
+        max_seqlen_k: int. Maximum key sequence length in the batch.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+    Return:
+        out: (total, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
+    return FlashAttnVarlenFunc.apply(
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        softmax_scale,
+        causal,
+    )

+ 27 - 12
hopper/flash_fwd_kernel.h

@@ -24,11 +24,12 @@ namespace flash {
 
 using namespace cute;
 
-template <typename Ktraits, bool Is_causal, typename TileScheduler>
+template <typename Ktraits, bool Is_causal, typename TileScheduler, typename Seqlen_traits>
 __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
-    compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
-                    CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits>::Params const epilogue_params,
-                    CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params
+    compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>::Params const mainloop_params,
+                    CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits, Seqlen_traits>::Params const epilogue_params,
+                    CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params,
+                    Seqlen_traits seqlen_traits_q, Seqlen_traits seqlen_traits_k
                     ) {
 
     using Element = typename Ktraits::Element;
@@ -46,8 +47,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
     // static constexpr int kBlockN = Ktraits::kBlockN;
     // constexpr int kHeadDim = Ktraits::kHeadDim;
 
-    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal>;
-    using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits>;
+    using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Seqlen_traits>;
+    using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits>;
 
     using MainloopPipeline = typename Ktraits::MainloopPipeline;
     using PipelineParams = typename MainloopPipeline::Params;
@@ -115,14 +116,21 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
                 auto block_coord = work_tile_info.get_block_coord(scheduler_params);
                 auto [m_block, bidh, bidb] = block_coord;
 
-                int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
+                seqlen_traits_q.init(bidb);
+                seqlen_traits_k.init(bidb);
+                if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
+                    continue;
+                }
+                int n_block_max = collective_mainloop.get_n_block_max(
+                    mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
                 if (Is_causal && n_block_max <= 0) {
                     scheduler.prefetch_next_work(scheduler_params, work_tile_info);
                     scheduler.broadcast_next_work(work_tile_info);
                     continue;
                 }
                 collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
-                                         shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
+                                         shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
+                                         seqlen_traits_q, seqlen_traits_k);
                 ++work_idx;
             }
             collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
@@ -154,17 +162,24 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
             auto block_coord = work_tile_info.get_block_coord(scheduler_params);
             auto [m_block, bidh, bidb] = block_coord;
 
-            int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
+            seqlen_traits_q.init(bidb);
+            seqlen_traits_k.init(bidb);
+            if (m_block * kBlockM >= seqlen_traits_q.actual_seq_len) {
+                continue;
+            }
+            int n_block_max = collective_mainloop.get_n_block_max(
+                mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
             if (Is_causal && n_block_max <= 0) {  // We exit early and write 0 to gO and -inf to gLSE.
-                collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord);
+                collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
                 continue;
             }
 
             collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
-                                    tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage);
+                                    tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage,
+                                    seqlen_traits_q, seqlen_traits_k);
                                     // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
             collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
-                                      threadIdx.x - NumCopyThreads, block_coord);
+                                      threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
 
             ++work_idx;
         }

+ 61 - 24
hopper/flash_fwd_launch_template.h

@@ -14,41 +14,61 @@
 #include "tile_scheduler.hpp"
 #include "flash_fwd_kernel.h"
 #include "kernel_traits.h"
+#include "seq_len.h"
 #include "utils.h"
 
 
-template<typename Kernel_traits, bool Is_causal>
+template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
 void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     using Element = typename Kernel_traits::Element;
     using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
     using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
 
     // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
-    using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal>;
-    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>;
-    using Scheduler = std::conditional_t<!Is_causal,
-        flash::StaticPersistentTileScheduler,
-        flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>>;
-        // flash::SingleTileScheduler>;
+    using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Seqlen_traits>;
+    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
+    using Scheduler = std::conditional_t<
+        Seqlen_traits::kUseVarSeqLen, 
+        flash::SingleTileScheduler,
+        std::conditional_t<!Is_causal,
+            flash::StaticPersistentTileScheduler,
+            flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>
+    >>;
+    // using Scheduler = flash::SingleTileScheduler;
+    Seqlen_traits seqlen_traits_q(
+        params.total_q, params.seqlen_q, params.cu_seqlens_q);
+    Seqlen_traits seqlen_traits_k(
+        params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);
     typename CollectiveMainloop::Params mainloop_params =
         CollectiveMainloop::to_underlying_arguments({
             static_cast<Element const*>(params.q_ptr),
-            {params.seqlen_q, params.d, params.h, params.b},  // shape_Q
-            {params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride},  // stride_Q
+            seqlen_traits_q.get_gmem_layout(
+                params.seqlen_q, params.d, params.h, params.b, 
+                params.q_row_stride, params.q_head_stride, params.q_batch_stride
+            ),  // layout_Q
             static_cast<Element const*>(params.k_ptr),
-            {params.seqlen_k, params.d, params.h_k, params.b},  // shape_K
-            {params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride},  // stride_K
+            seqlen_traits_k.get_gmem_layout(
+                params.seqlen_k, params.d, params.h_k, params.b, 
+                params.k_row_stride, params.k_head_stride, params.k_batch_stride
+            ),  // layout_K
             static_cast<Element const*>(params.v_ptr),
-            {params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride},  // stride_V
+            seqlen_traits_k.get_gmem_layout(
+                params.seqlen_k, params.d, params.h_k, params.b, 
+                params.v_row_stride, params.v_head_stride, params.v_batch_stride
+            ),  // layout_V
             params.scale_softmax_log2
         });
     typename CollectiveEpilogue::Params epilogue_params =
         CollectiveEpilogue::to_underlying_arguments({
             static_cast<Element*>(params.o_ptr),
-            {params.seqlen_q, params.d, params.h, params.b},  // shape_O
-            {params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride},  // stride_O
+            seqlen_traits_q.get_gmem_layout(
+                params.seqlen_q, params.d, params.h, params.b,
+                params.o_row_stride, params.o_head_stride, params.o_batch_stride
+            ),  // layout_O
             static_cast<float*>(params.softmax_lse_ptr),
-            {_1{}, params.seqlen_q, params.h * params.seqlen_q},  // stride_LSE
+            seqlen_traits_q.get_lse_gmem_layout(
+                params.seqlen_q, params.h, params.b
+            )  // layout_LSE
         });
 
     int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
@@ -58,7 +78,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
     // Get the ptr to kernel function.
     void *kernel;
-    kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler>;
+    kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
     int smem_size = sizeof(typename Kernel_traits::SharedStorage);
     // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
     // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
@@ -81,7 +101,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     dim3 block_dims(ctaSize);
     dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
     cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
-    cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
+    cutlass::launch_kernel_on_cluster(
+        launch_params, kernel, mainloop_params, epilogue_params, 
+        scheduler_params, seqlen_traits_q, seqlen_traits_k);
     CHECK_CUDA_KERNEL_LAUNCH();
 }
 
@@ -89,7 +111,12 @@ template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, Is_causal>(params, stream);
+        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            run_flash_fwd<
+                Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, 
+                Is_causal, Seqlen_traits
+            >(params, stream);
+        });
     });
 }
 
@@ -97,9 +124,14 @@ template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        // Only use Cluster if number of tiles along seqlen_q is even
-        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
+        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
+            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+                run_flash_fwd<
+                    Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>, 
+                    Is_causal, Seqlen_traits
+                >(params, stream);
+            });
         });
     });
 }
@@ -108,9 +140,14 @@ template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        // Only use Cluster if number of tiles along seqlen_q is even
-        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
+        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
+            // Only use Cluster if number of tiles along seqlen_q is even
+            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
+                run_flash_fwd<
+                    Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>, 
+                    Is_causal, Seqlen_traits
+                >(params, stream);
+            });
         });
     });
 }

+ 47 - 33
hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

@@ -21,7 +21,7 @@ namespace flash {
 
 using namespace cute;
 
-template <typename Ktraits, bool Is_causal>
+template <typename Ktraits, bool Is_causal, typename Seqlen_traits>
 struct CollectiveMainloopFwd {
 
     using Element = typename Ktraits::Element;
@@ -64,19 +64,24 @@ struct CollectiveMainloopFwd {
     //     decltype(tile_to_shape(SmemLayoutAtomVTMA{},
     //                            make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
 
-    using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>;  // (seqlen, d, head, batch)
-    using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
-
     using TMA_Q = decltype(make_tma_copy(
         GmemTiledCopyQ{},
-        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}),
+        make_tensor(
+            make_gmem_ptr(static_cast<Element const*>(nullptr)), 
+            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), 
+            typename Seqlen_traits::StrideT{}
+        ),
         SmemLayoutQ{},
         select<0, 2>(TileShape_MNK{}),
         _1{}));  // no mcast for Q
 
     using TMA_KV = decltype(make_tma_copy(
         GmemTiledCopyKV{},
-        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}),
+        make_tensor(
+            make_gmem_ptr(static_cast<Element const*>(nullptr)), 
+            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), 
+            typename Seqlen_traits::StrideT{}
+        ),
         take<0, 2>(SmemLayoutK{}),
         select<1, 2>(TileShape_MNK{}),
         size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
@@ -95,20 +100,19 @@ struct CollectiveMainloopFwd {
     // Host side kernel arguments
     struct Arguments {
         Element const* ptr_Q;
-        ShapeQKV const shape_Q;
-        StrideQKV const stride_Q;
+        typename Seqlen_traits::LayoutT layout_Q;
         Element const* ptr_K;
-        ShapeQKV const shape_K;
-        StrideQKV const stride_K;
+        typename Seqlen_traits::LayoutT layout_K;
         Element const* ptr_V;
-        StrideQKV const stride_V;
+        typename Seqlen_traits::LayoutT layout_V;
         float const softmax_scale_log2;
     };
 
     // Device side kernel params
     struct Params {
-        ShapeQKV const shape_Q;
-        ShapeQKV const shape_K;
+        typename Seqlen_traits::LayoutT layout_Q;
+        typename Seqlen_traits::LayoutT layout_K;
+        typename Seqlen_traits::LayoutT layout_V;
         cutlass::FastDivmod qhead_per_khead_divmod;
         TMA_Q tma_load_Q;
         TMA_KV tma_load_K, tma_load_V;
@@ -118,29 +122,29 @@ struct CollectiveMainloopFwd {
 
     static Params
     to_underlying_arguments(Arguments const& args) {
-        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
+        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
         TMA_Q tma_load_Q = make_tma_copy(
             GmemTiledCopyQ{},
             mQ,
             SmemLayoutQ{},
             select<0, 2>(TileShape_MNK{}),
             _1{}); // no mcast for Q
-        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
+        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
         TMA_KV tma_load_K = make_tma_copy(
             GmemTiledCopyKV{},
             mK,
             SmemLayoutK{}(_, _, _0{}),
             select<1, 2>(TileShape_MNK{}),
             size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
-        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V);
+        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
         TMA_KV tma_load_V = make_tma_copy(
             GmemTiledCopyKV{},
             mV,
             SmemLayoutV{}(_, _, _0{}),
             select<1, 2>(TileShape_MNK{}),
             size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
-        return {args.shape_Q, args.shape_K,
-                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
+        return {args.layout_Q, args.layout_K, args.layout_V,
+                cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
                 tma_load_Q, tma_load_K, tma_load_V,
                 args.softmax_scale_log2};
     }
@@ -154,11 +158,15 @@ struct CollectiveMainloopFwd {
     }
 
     CUTLASS_DEVICE
-    int get_n_block_max(Params const& mainloop_params, int m_block) {
+    int get_n_block_max(
+          Params const& mainloop_params, int m_block, 
+          const Seqlen_traits& seqlen_traits_q,
+          const Seqlen_traits& seqlen_traits_k
+        ) {
         static constexpr int kBlockM = get<0>(TileShape_MNK{});
         static constexpr int kBlockN = get<1>(TileShape_MNK{});
-        int const seqlen_q = get<0>(mainloop_params.shape_Q);
-        int const seqlen_k = get<0>(mainloop_params.shape_K);
+        int const seqlen_q = seqlen_traits_q.actual_seq_len;
+        int const seqlen_k = seqlen_traits_k.actual_seq_len;
         int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
         if constexpr (Is_causal) {
             n_block_max = std::min(n_block_max,
@@ -179,16 +187,18 @@ struct CollectiveMainloopFwd {
          typename Scheduler::Params const& scheduler_params,
          typename Scheduler::WorkTileInfo& work_tile_info,
          cute::tuple<int32_t, int32_t, int32_t> block_coord,
-         int work_idx
+         int work_idx,
+         const Seqlen_traits& seqlen_traits_q,
+         const Seqlen_traits& seqlen_traits_k
          ) {
 
         Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
         Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
         Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
 
-        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.shape_Q);
-        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K);
-        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K);
+        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
+        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
+        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
 
         auto [m_block, bidh, bidb] = block_coord;
         int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
@@ -197,9 +207,12 @@ struct CollectiveMainloopFwd {
         uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
         constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
         uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
-        Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
-        Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
-        Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
+        Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
+            mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
+        Tensor gK = seqlen_traits_k.get_local_tile_tensor(
+            mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
+        Tensor gV = seqlen_traits_k.get_local_tile_tensor(
+            mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
 
         Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
         Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
@@ -218,7 +231,7 @@ struct CollectiveMainloopFwd {
             }
         }
 
-        int n_block_max = get_n_block_max(mainloop_params, m_block);
+        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
         int n_block = n_block_max - 1;
 
         int lane_predicate = cute::elect_one_sync();
@@ -331,7 +344,9 @@ struct CollectiveMainloopFwd {
         int thread_idx,
         int work_idx,
         int m_block,
-        SharedStorage& shared_storage
+        SharedStorage& shared_storage,
+        const Seqlen_traits& seqlen_traits_q,
+        const Seqlen_traits& seqlen_traits_k
         ) {
         static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");
 
@@ -360,8 +375,8 @@ struct CollectiveMainloopFwd {
         };
 
         tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
-        int const seqlen_q = get<0>(mainloop_params.shape_Q);
-        int const seqlen_k = get<0>(mainloop_params.shape_K);
+        int const seqlen_q = seqlen_traits_q.actual_seq_len;
+        int const seqlen_k = seqlen_traits_k.actual_seq_len;
         int n_block = n_block_count - 1;
 
         cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
@@ -483,4 +498,3 @@ struct CollectiveMainloopFwd {
 };
 
 } // namespace flash
-

+ 168 - 0
hopper/seq_len.h

@@ -0,0 +1,168 @@
+/******************************************************************************
+ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cutlass/cutlass.h>
+#include <cute/layout.hpp>
+
+namespace flash {
+
+static constexpr int kMaxTileSize = 128;
+
+template <bool UseVarSeqLen> class SeqLenTraits {
+public:
+  // Total number of queries / keys. Unpadded.
+  int sum_s = 0;
+  // seq len offsets.
+  int *cu_seq_len = nullptr;
+  // actual seq len array.
+  int *seq_used = nullptr;
+  // seq len of the current batch.
+  int actual_seq_len = -1;
+
+  // Whether this is for fixed-seq-len or var-seq-len.
+  static constexpr bool kUseVarSeqLen = UseVarSeqLen;
+
+  using ShapeT = std::conditional_t<
+      UseVarSeqLen, 
+      cute::Shape<int32_t, int32_t, int32_t>, 
+      cute::Shape<int32_t, int32_t, int32_t, int32_t>
+  >;
+  using StrideT = std::conditional_t<
+      UseVarSeqLen, 
+      cute::Shape<int64_t, _1, int64_t>, 
+      cute::Shape<int64_t, _1, int64_t, int64_t>
+  >;
+  using LayoutT = cute::Layout<ShapeT, StrideT>;
+
+  using ShapeLseT = std::conditional_t<
+      UseVarSeqLen, 
+      cute::Shape<int32_t, int32_t>, 
+      cute::Shape<int32_t, int32_t, int32_t>
+  >;
+  using StrideLseT = std::conditional_t<
+      UseVarSeqLen, 
+      cute::Shape<int64_t, _1>, 
+      cute::Shape<int64_t, int64_t, _1>
+  >;
+  using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;
+
+  CUTLASS_HOST SeqLenTraits() {}
+
+  CUTLASS_HOST SeqLenTraits(
+      int sum_s, int max_seq_len, int *cu_seq_len = nullptr, int *seq_used = nullptr): 
+      sum_s(sum_s), cu_seq_len(cu_seq_len), seq_used(seq_used), actual_seq_len(max_seq_len) {}
+
+  // Returns the layout of a tensor in MKHB format in global memory.
+  // padded: only useful for var-seq-len for dq_accum and softmax_d.
+  CUTLASS_HOST_DEVICE auto get_gmem_layout(
+      int m, int k, int h, int b, 
+      int64_t m_stride, int64_t h_stride, int64_t b_stride,
+      bool padded = false) const {
+    static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen.");
+    return make_layout(make_shape(m, k, h, b),
+                       make_stride(m_stride, cute::_1{}, h_stride, b_stride));
+  }
+
+  // Returns the layout of a tensor in MKHB format in global memory.
+  // padded: only useful for var-seq-len for dq_accum and softmax_d.
+  CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(
+      int m, int h, int b, bool padded = false) const {
+    static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen.");
+    return make_layout(make_shape(b, h, m),
+                       make_stride(int64_t(h * m), int64_t(m), cute::_1()));
+  }
+
+  CUTLASS_DEVICE void init(int bidb) {}
+
+  template <typename MTensor, typename Shape>
+  CUTLASS_DEVICE auto get_local_tile_tensor(
+      const MTensor &m_tensor, const Shape &tile_shape, 
+      int bidh, int bidb, bool padded = false) const {
+    auto g_tensor = local_tile(
+      m_tensor(_, _, bidh, bidb), tile_shape, make_coord(_, _0{}));
+    return g_tensor;
+  }
+
+  template <typename MTensor, typename Shape>
+  CUTLASS_DEVICE auto get_lse_local_tile_tensor(
+      const MTensor &m_tensor, const Shape &tile_shape, 
+      int bidh, int bidb, bool padded = false) const {
+    auto g_tensor = local_tile(m_tensor(bidb, bidh, _), tile_shape, make_coord(_));
+    return g_tensor;
+  }
+};
+
+using FixedSeqLenTraits = SeqLenTraits<false>;
+
+using VarSeqLenTraits = SeqLenTraits<true>;
+
+// Returns the static layout of a var-seq-len tensor in global memory based on
+// max_seq_len and max_batch_size.
+// padded: only useful for var-seq-len for dq_accum and softmax_d.
+// When padded is True, use B_M + kMaxTileSize * B as the total B_M.
+template <>
+CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_gmem_layout(
+    int m, int k, int h, int b, 
+    int64_t m_stride, int64_t h_stride, int64_t b_stride,
+    bool padded) const {
+  return make_layout(
+    make_shape(sum_s + (padded ? kMaxTileSize * b : 0), k, h), 
+    make_stride(m_stride, cute::_1{}, h_stride));
+}
+
+// padded: only useful for var-seq-len for dq_accum and softmax_d.
+// When padded is True, use B_M + kMaxTileSize * B as the total B_M.
+template <>
+CUTLASS_HOST_DEVICE auto VarSeqLenTraits::get_lse_gmem_layout(
+    int m, int h, int b, bool padded) const {
+  return make_layout(
+    make_shape(h, sum_s + (padded ? kMaxTileSize * b : 0)), 
+    make_stride(int64_t(sum_s + (padded ? kMaxTileSize * b : 0)), cute::_1()));
+}
+
+template <>
+CUTLASS_DEVICE void VarSeqLenTraits::init(int bidb) {
+  actual_seq_len = 
+      seq_used ? seq_used[bidb] : (cu_seq_len[bidb + 1] - cu_seq_len[bidb]);
+}
+
+template <>
+template <typename MTensor, typename Shape>
+CUTLASS_DEVICE auto VarSeqLenTraits::get_local_tile_tensor(
+    const MTensor &m_tensor, const Shape &tile_shape,
+    int bidh, int bidb, bool padded) const {
+  auto g_offset = local_tile(
+      m_tensor(_, _, bidh), 
+      cute::make_shape(1, get<1>(tile_shape)), 
+      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0), _0{}));
+  auto g_sequence = make_tensor(
+      g_offset.data(), 
+      make_layout(
+        cute::make_shape(actual_seq_len, get<1>(tile_shape)), 
+        g_offset.stride()
+      ));
+  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
+  return g_tensor;
+}
+
+template <>
+template <typename MTensor, typename Shape>
+CUTLASS_DEVICE auto VarSeqLenTraits::get_lse_local_tile_tensor(
+    const MTensor &m_tensor, const Shape &tile_shape,
+    int bidh, int bidb, bool padded) const {
+  auto g_offset = local_tile(
+      m_tensor(bidh, _), cute::make_shape(_1{}), 
+      make_coord(cu_seq_len[bidb] + (padded ? kMaxTileSize * bidb : 0)));
+  auto g_sequence = make_tensor(
+      g_offset.data(), 
+      make_layout(cute::make_shape(actual_seq_len), cute::make_shape(_1{})));
+  auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_));
+  return g_tensor;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash

+ 6 - 10
hopper/static_switch.h

@@ -66,18 +66,14 @@
     }                                                                          \
   }()
 
-#define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, SEQ_LEN_OUT_OF_BOUND_CHECK, ...)        \
+#define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, NAME, ...)                              \
   [&] {                                                                        \
-    if (!USE_VAR_SEQ_LEN) {                                                    \
-      if (SEQ_LEN_OUT_OF_BOUND_CHECK) {                                        \
-        using kSeqLenTraitsType = FixedSeqLenTraits<true>;                     \
-        return __VA_ARGS__();                                                  \
-      } else {                                                                 \
-        using kSeqLenTraitsType = FixedSeqLenTraits<false>;                    \
-        return __VA_ARGS__();                                                  \
-      }                                                                        \
+    bool useSeqLen = USE_VAR_SEQ_LEN;                                          \
+    if (useSeqLen) {                                                           \
+      using NAME = flash::VarSeqLenTraits;                                     \
+      return __VA_ARGS__();                                                    \
     } else {                                                                   \
-      using kSeqLenTraitsType = VarSeqLenTraits;                               \
+      using NAME = flash::FixedSeqLenTraits;                                   \
       return __VA_ARGS__();                                                    \
     }                                                                          \
   }() 

+ 170 - 113
hopper/test_flash_attn.py

@@ -5,40 +5,12 @@ import torch
 import torch.nn.functional as F
 
 from einops import rearrange, repeat
-from flash_attn_interface import flash_attn_func
+from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+from tests.test_util import generate_random_padding_mask, generate_qkv, construct_local_mask, attention_ref
 
 ABS_TOL = 5e-3
 REL_TOL = 1e-1
 
-def construct_local_mask(
-    seqlen_q,
-    seqlen_k,
-    window_size=(-1, -1),  # -1 means infinite window size
-    query_padding_mask=None,
-    key_padding_mask=None,
-    device=None,
-):
-    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
-    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
-    sk = (
-        seqlen_k
-        if key_padding_mask is None
-        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
-    )
-    sq = (
-        seqlen_q
-        if query_padding_mask is None
-        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
-    )
-    if window_size[0] < 0:
-        return col_idx > row_idx + sk - sq + window_size[1]
-    else:
-        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
-        return torch.logical_or(
-            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
-            col_idx < row_idx + sk - sq - window_size[0],
-        )
-
 def print_diffs(out, out_ref):
     out_1d = out.flatten()
     out_ref_1d = out_ref.flatten()
@@ -51,86 +23,6 @@ def print_diffs(out, out_ref):
             print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}")
 
 
-def attention_ref(
-    q,
-    k,
-    v,
-    query_padding_mask=None,
-    key_padding_mask=None,
-    attn_bias=None,
-    dropout_p=0.0,
-    dropout_mask=None,
-    causal=False,
-    upcast=True,
-    reorder_ops=False,
-):
-    """
-    Arguments:
-        q: (batch_size, seqlen_q, nheads, head_dim)
-        k: (batch_size, seqlen_k, nheads, head_dim)
-        v: (batch_size, seqlen_k, nheads, head_dim)
-        query_padding_mask: (batch_size, seqlen_q)
-        key_padding_mask: (batch_size, seqlen_k)
-        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
-        dropout_p: float
-        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
-        causal: whether to apply causal masking
-        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
-            output back to fp16/bf16.
-        reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
-            without changing the math. This is to estimate the numerical error from operation
-            reordering.
-    Output:
-        output: (batch_size, seqlen_q, nheads, head_dim)
-        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
-    """
-    dtype_og = q.dtype
-    if upcast:
-        q, k, v = q.float(), k.float(), v.float()
-    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
-    d = q.shape[-1]
-    if not reorder_ops:
-        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
-    else:
-        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
-    if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
-    if causal:
-        local_mask = construct_local_mask(
-            seqlen_q,
-            seqlen_k,
-            (-1, 0),
-            None,
-            None,
-            q.device,
-        )
-        scores.masked_fill_(local_mask, float("-inf"))
-    if attn_bias is not None:
-        scores = scores + attn_bias
-    attention = torch.softmax(scores, dim=-1).to(v.dtype)
-    # We want to mask here so that the attention matrix doesn't have any NaNs
-    # Otherwise we'll get NaN in dV
-    if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
-    # Some rows might be completely masked out so we fill them with zero instead of NaN
-    if causal:
-        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
-    dropout_scaling = 1.0 / (1 - dropout_p)
-    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
-    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
-    if dropout_mask is not None:
-        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
-    else:
-        attention_drop = attention
-    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
-    if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
-    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
-
-
-
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@@ -142,10 +34,11 @@ def attention_ref(
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 @pytest.mark.parametrize("d", [64, 128, 256])
-# @pytest.mark.parametrize("d", [256])
+# @pytest.mark.parametrize("d", [128])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
+        (257, 1),
         (64, 128),
         (128, 128),
         (256, 256),
@@ -175,8 +68,9 @@ def test_flash_attn_output(
     batch_size = 9
     nheads = 6
     nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
-    # batch_size = 1
-    # nheads = 1
+    # nheads_kv = 2
+    # batch_size = 9
+    # nheads = 6
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
     k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
     v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
@@ -244,9 +138,172 @@ def test_flash_attn_output(
 
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
+    # breakpoint()
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
 
     # if d <= 128:
     #     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
     #     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
     #     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+
+
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
+# @pytest.mark.parametrize('d', [128])
+@pytest.mark.parametrize("d", [64, 128, 256])
+@pytest.mark.parametrize(
+    "seqlen_q,seqlen_k",
+    [
+        (1, 1),
+        (1, 3),
+        (2, 1),
+        (511, 1),
+        (3, 513),
+        (64, 128),
+        (113, 203),
+        (128, 128),
+        (128, 217),
+        (113, 211),
+        (108, 256),
+        (256, 512),
+        (384, 256),
+        (512, 256),
+        (640, 128),
+        (1024, 1024),
+        (1023, 1024),
+        (1024, 1023),
+        (2048, 2048),
+    ],
+)
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
+def test_flash_attn_varlen_output(
+    seqlen_q, seqlen_k, d, causal, mha_type, dtype
+):
+    if (
+        max(seqlen_q, seqlen_k) >= 2048
+        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
+    ):
+        pytest.skip()  # Reference implementation OOM
+    device = "cuda"
+    # set seed
+    torch.random.manual_seed(0)
+    # batch_size = 1
+    # nheads = 1
+    batch_size = 9
+    nheads = 6
+    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
+ 
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    k = torch.randn(
+        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
+    )
+    v = torch.randn(
+        batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
+    )
+
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
+
+    (
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        q,
+        k,
+        v,
+        output_pad_fn,
+        dq_pad_fn,
+        dk_pad_fn,
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    # print("cu_seqlens_q: ", cu_seqlens_q)
+    # print("cu_seqlens_k: ", cu_seqlens_k)
+    # print("q_unpad, shape: ", q_unpad.shape)
+    # print("k_unpad, shape: ", k_unpad.shape)
+    # print("v_unpad, shape: ", v_unpad.shape)
+    out_unpad, sm_lse = flash_attn_varlen_func(
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        causal=causal,
+    )
+    out = output_pad_fn(out_unpad)
+    dropout_mask = None
+
+    out_ref, attn_ref = attention_ref(
+        q,
+        k,
+        v,
+        query_padding_mask,
+        key_padding_mask,
+        causal=causal,
+    )
+    out_pt, attn_pt = attention_ref(
+        q,
+        k,
+        v,
+        query_padding_mask,
+        key_padding_mask,
+        causal=causal,
+        upcast=False,
+        reorder_ops=True,
+    )
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
+    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
+
+    # g = torch.randn_like(out)
+    # if d <= 128:
+    #     (
+    #         dq_unpad,
+    #         dk_unpad,
+    #         dv_unpad,
+    #     ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
+    #     dk = dk_pad_fn(dk_unpad)
+    #     dv = dk_pad_fn(dv_unpad)
+    #     (
+    #         dq_ref,
+    #         dk_ref,
+    #         dv_ref,
+    #     ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    #     (
+    #         dq_pt,
+    #         dk_pt,
+    #         dv_pt,
+    #     ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    #     dq = dq_pad_fn(dq_unpad)
+    #     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
+    #     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
+    #     print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
+    #     print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
+    #     print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
+    #     print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
+    #     print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
+    #     print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
+    #     print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
+    #     print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
+    #     print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
+    #     print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
+
+    # Check that FlashAttention's numerical error is at most twice the numerical error
+    # of a Pytorch implementation.
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+
+    # if d <= 128:
+    #     assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
+    #     assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
+    #     assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()

+ 90 - 0
hopper/utils.h

@@ -15,6 +15,7 @@
 #endif
 
 #include <cute/tensor.hpp>
+#include <cute/atom/copy_atom.hpp>
 
 #include <cutlass/array.h>
 #include <cutlass/cutlass.h>
@@ -228,4 +229,93 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layou
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template <int NumCopyThreads, typename ElemO, typename TMACopyO, typename LayoutO, 
+          typename TileShapeO, typename SMemO, typename SeqLenTraits>
+__forceinline__ __device__ void write_tma(
+        ElemO* O, const TMACopyO& tma_store_O,
+        const LayoutO& layout_O, const TileShapeO& tile_shape_O,
+        const SMemO& sO, int m_block, int bidh, int bidb,
+        const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
+    Tensor mO = tma_store_O.get_tma_tensor(layout_O.shape());
+    Tensor gO = seqlen_traits_o.get_local_tile_tensor(
+        mO, tile_shape_O, bidh, bidb
+    )(_, _, m_block);  // (M, K)
+    auto block_tma_O = tma_store_O.get_slice(_0{});
+    Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
+    Tensor tOsO = block_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)
+
+    int const lane_predicate = cute::elect_one_sync();
+    int const warp_idx = cutlass::canonical_warp_idx_sync();
+    if (warp_idx == write_warp_idx && lane_predicate) {
+        cute::copy(tma_store_O, tOsO, tOgO);
+        tma_store_arrive();
+    }
+    // Note: no wait here.
+    // tma_store_wait<0>();
+}
+
+template <int NumCopyThreads, typename ElemO, typename TiledCopyO, typename LayoutO, 
+          typename TileShapeO, typename SMemO, typename SeqLenTraits>
+__forceinline__ __device__ void write_tiled(
+        ElemO* O, const TiledCopyO& tiled_copy_O,
+        const LayoutO& layout_O, const TileShapeO& tile_shape_O,
+        const SMemO& sO, int m_block, int bidh, int bidb,
+        const SeqLenTraits& seqlen_traits_o) {
+    Tensor mO = make_tensor(make_gmem_ptr(O), layout_O);
+    Tensor gO = seqlen_traits_o.get_local_tile_tensor(
+        mO, tile_shape_O, bidh, bidb
+    )(_, _, m_block);  // (M, K)
+
+    ThrCopy thr_copy_O = tiled_copy_O.get_slice(threadIdx.x - NumCopyThreads);
+    Tensor tOgO = thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K,k)
+    Tensor tOsO = thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K)
+
+    // Prepare for TiledCopy.
+    // Grouping is needed because cute::copy_if() does group_modes<1, R> for src and dst.
+    // After grouping, the first dim is number of elements to read together.
+    Tensor tOsOFlatten = cute::flatten(tOsO);
+    Tensor tOsOGroup = cute::group_modes<1, rank(tOsOFlatten)>(tOsOFlatten);
+    Tensor tOgOFlatten = cute::flatten(tOgO);
+    Tensor tOgOGroup = cute::group_modes<1, rank(tOgOFlatten)>(tOgOFlatten);
+
+    // Get thread coords to global index mapping.
+    Tensor gOCounting = cute::make_identity_tensor(gO.shape());
+    Tensor tSgOCounting = thr_copy_O.partition_D(gOCounting);
+    Tensor tSgOCountingFlatten = cute::flatten(tSgOCounting);
+    Tensor tSgOCountingGrouped =
+        cute::group_modes<1, rank(tSgOCountingFlatten)>(tSgOCountingFlatten);
+
+    // Write out to GMEM.
+    const int kNumMsPerTile = get<0>(tile_shape_O);
+    int cta_m = std::min(
+        seqlen_traits_o.actual_seq_len - m_block * kNumMsPerTile, kNumMsPerTile
+    );
+    if (cta_m == kNumMsPerTile) {
+        copy(tiled_copy_O, tOsOGroup, tOgOGroup);
+    } else {
+        auto predicate_fn = [&](auto coords) {
+            auto s_coords = tSgOCountingGrouped(_0{}, coords);
+            return elem_less(get<0>(s_coords), cta_m);
+        };
+        copy_if(tiled_copy_O, predicate_fn, tOsOGroup, tOgOGroup);
+    }
+}
+
+template <bool IsTMACopy, int NumCopyThreads, typename ElemO, 
+          typename TMACopyO, typename TiledCopyO, typename LayoutO, 
+          typename TileShapeO, typename SMemO, typename SeqLenTraits>
+__forceinline__ __device__ void write_O(
+        ElemO* O, const TMACopyO& tma_copy_O, const TiledCopyO& tiled_copy_O,
+        const LayoutO& layout_O, const TileShapeO& tile_shape_O,
+        const SMemO& sO, int m_block, int bidh, int bidb,
+        const SeqLenTraits& seqlen_traits_o, int write_warp_idx) {
+    if constexpr (IsTMACopy) {
+        write_tma<NumCopyThreads>(O, tma_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o, write_warp_idx);
+    } else {
+        write_tiled<NumCopyThreads>(O, tiled_copy_O, layout_O, tile_shape_O, sO, m_block, bidh, bidb, seqlen_traits_o);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace flash

+ 254 - 0
tests/test_util.py

@@ -0,0 +1,254 @@
+import math
+
+import torch
+from einops import rearrange, repeat
+from flash_attn.bert_padding import pad_input, unpad_input
+
+
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
+    assert mode in ["full", "random", "third"]
+    if mode == "full":
+        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
+    elif mode == "random":
+        lengths = torch.randint(
+            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+        )
+    elif mode == "third":
+        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+    padding_mask = (
+        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
+    )
+    return padding_mask
+
+
+def generate_qkv(
+    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, d)
+        k: (batch_size, seqlen_k, nheads_k, d)
+        v: (batch_size, seqlen_k, nheads_k, d)
+        query_padding_mask: (batch_size, seqlen), bool
+        key_padding_mask: (batch_size, seqlen), bool
+    """
+    assert not (kvpacked and qkvpacked)
+    batch_size, seqlen_q, nheads, d = q.shape
+    _, seqlen_k, nheads_k, _ = k.shape
+    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
+    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+
+    if query_padding_mask is not None:
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
+        output_pad_fn = lambda output_unpad: pad_input(
+            output_unpad, indices_q, batch_size, seqlen_q
+        )
+    else:
+        q_unpad = rearrange(q, "b s h d -> (b s) h d")
+        cu_seqlens_q = torch.arange(
+            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
+        )
+        max_seqlen_q = seqlen_q
+        output_pad_fn = lambda output_unpad: rearrange(
+            output_unpad, "(b s) h d -> b s h d", b=batch_size
+        )
+
+    if key_padding_mask is not None:
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+    else:
+        k_unpad = rearrange(k, "b s h d -> (b s) h d")
+        v_unpad = rearrange(v, "b s h d -> (b s) h d")
+        cu_seqlens_k = torch.arange(
+            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
+        )
+        max_seqlen_k = seqlen_k
+
+    if qkvpacked:
+        assert (query_padding_mask == key_padding_mask).all()
+        assert nheads == nheads_k
+        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
+        qkv = torch.stack([q, k, v], dim=2)
+        if query_padding_mask is not None:
+            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
+        else:
+            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
+                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            qkv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            max_seqlen_q,
+            qkv.detach().requires_grad_(),
+            output_pad_fn,
+            dqkv_pad_fn,
+        )
+    elif kvpacked:
+        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
+        kv = torch.stack([k, v], dim=2)
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            dkv_pad_fn = lambda dkv_unpad: rearrange(
+                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
+            )
+        return (
+            q_unpad.detach().requires_grad_(),
+            kv_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            kv.detach().requires_grad_(),
+            output_pad_fn,
+            dq_pad_fn,
+            dkv_pad_fn,
+        )
+    else:
+        dq_pad_fn = output_pad_fn
+        if key_padding_mask is not None:
+            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
+        else:
+            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
+        return (
+            q_unpad.detach().requires_grad_(),
+            k_unpad.detach().requires_grad_(),
+            v_unpad.detach().requires_grad_(),
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            q.detach().requires_grad_(),
+            k.detach().requires_grad_(),
+            v.detach().requires_grad_(),
+            output_pad_fn,
+            dq_pad_fn,
+            dk_pad_fn,
+        )
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    query_padding_mask=None,
+    key_padding_mask=None,
+    device=None,
+    key_leftpad=None,
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            col_idx < row_idx + sk - sq - window_size[0],
+        )
+
+
+def attention_ref(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    attn_bias=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
+    upcast=True,
+    reorder_ops=False,
+    key_leftpad=None,
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k: (batch_size, seqlen_k, nheads_k, head_dim)
+        v: (batch_size, seqlen_k, nheads_k, head_dim)
+        query_padding_mask: (batch_size, seqlen_q)
+        key_padding_mask: (batch_size, seqlen_k)
+        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
+        dropout_p: float
+        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        causal: whether to apply causal masking
+        window_size: (int, int), left and right window size
+        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
+            output back to fp16/bf16.
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
+            without changing the math. This is to estimate the numerical error from operation
+            reordering.
+    Output:
+        output: (batch_size, seqlen_q, nheads, head_dim)
+        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    dtype_og = q.dtype
+    if upcast:
+        q, k, v = q.float(), k.float(), v.float()
+    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    d = q.shape[-1]
+    if not reorder_ops:
+        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+    else:
+        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+    if softcap > 0:
+        scores /= softcap
+        scores = scores.tanh()
+        scores *= softcap
+    if key_padding_mask is not None:
+        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            q.device,
+            key_leftpad=key_leftpad,
+        )
+        scores.masked_fill_(local_mask, float("-inf"))
+    if attn_bias is not None:
+        scores = scores + attn_bias
+    attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    # Some rows might be completely masked out so we fill them with zero instead of NaN
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
+    # We want to mask here so that the attention matrix doesn't have any NaNs
+    # Otherwise we'll get NaN in dV
+    if query_padding_mask is not None:
+        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+    dropout_scaling = 1.0 / (1 - dropout_p)
+    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
+    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
+    if dropout_mask is not None:
+        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+    else:
+        attention_drop = attention
+    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+    if query_padding_mask is not None:
+        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)