/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>  // For FastDivMod
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/epilogue/collective/builders/sm90_common.inl"

#include "seqlen.h"
#include "named_barrier.hpp"
#include "pack_gqa.h"
#include "utils.h"

namespace flash {

using namespace cute;

template <class TileShape_MNK_, class ClusterShape_, class Element_, int NumEpilogueThreads_,
          bool Varlen_, bool PackGQA_, bool FP8PermuteCol=false>
struct CollectiveEpilogueFwd {

    using TileShape_MNK = TileShape_MNK_;
    using ClusterShape = ClusterShape_;
    using Element = Element_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool PackGQA = PackGQA_;
    static constexpr bool Use_smem = sizeof(Element) <= 2;
    static constexpr bool Use_TMA_O = !Varlen && Use_smem && !PackGQA;

    static constexpr int kBlockM = get<0>(TileShape_MNK{});
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});

    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;

    // These are for storing the output tensor without TMA (e.g., for setting output to zero)
    static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
    // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
    // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
    // we need to call divmod.
    static constexpr int kBytePerRow = kHeadDim * sizeof(Element);
    static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
    // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
    // If PackGQA, we split the work of computing O_ptr among threads in the same row, so we need this to be within a warp
    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
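    // Worked example (illustrative values): for Element = fp16 (2 bytes) and kHeadDim = 128,
    // kGmemElemsPerStore = 16 / 2 = 8 elements per 128-bit store, kBytePerRow = 256 bytes,
    // kBlockKGmem = 128 / 2 = 64 elements, and kGmemThreadsPerRow = 64 / 8 = 8, so 8 threads
    // cooperate on one 128-byte cache line of a row.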
    using GmemLayoutAtom = Layout<Shape<Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, Int<kGmemElemsPerStore>>>{}));  // Val layout, 8 or 16 vals per store

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>;  // (seqlen_q, d, head, batch, num_splits)
    using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
    using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>;            // (seqlen_q, head, batch, num_splits)
    // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
    using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
    using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
    // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
    using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
    using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;

    // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
    using CopyOpR2S = decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>());
    using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;

    // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
    // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
    // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
    //     cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
    // };
    struct TensorStorage : cute::aligned_struct<128> {
        cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
    };

    using TMA_O = decltype(make_tma_copy(
        GmemTiledCopyOTMA{},
        make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
        SmemLayoutO{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for O

    // Host side kernel arguments
    struct Arguments {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        int32_t const nheads_kv;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Device side kernel params
    struct Params {
        Element* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        ShapeOPacked const shape_O_packed;
        StrideOPacked const stride_O_packed;
        float* ptr_LSE;
        StrideLSE const stride_LSE;
        ShapeLSEPacked const shape_LSE_packed;
        StrideLSEPacked const stride_LSE_packed;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_O tma_store_O;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
        TMA_O tma_store_O = make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{});  // no mcast for O
        // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
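        // Illustrative example (values assumed): with 32 query heads and nheads_kv = 8,
        // qhead_per_khead = ceil(32 / 8) = 4. Packed coordinate ((h, m), d, hk, b, split) then
        // addresses the same element as (m, d, hk * 4 + h, b, split) in the unpacked O, which is
        // why the inner stride pair below is (stride_O_head, stride_O_seqlen) and the head stride
        // is multiplied by qhead_per_khead.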
        int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
        auto const shape_O_packed = cute::conditional_return<!PackGQA>(
            args.shape_O,
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
        );
        auto const stride_O_packed = cute::conditional_return<!PackGQA>(
            args.stride_O,
            make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
        );
        // If PackGQA, reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
        auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
            select<0, 2, 3, 4>(args.shape_O),
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
        );
        auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
            args.stride_LSE,
            make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
        );
        return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
                args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
                cutlass::FastDivmod(qhead_per_khead),
                tma_store_O, args.cu_seqlens, args.seqused};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        if constexpr (Use_TMA_O) {
            cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
        }
    }
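    // The store() epilogue below proceeds in three steps:
    //   Step 1: copy the accumulator (converted to Element) from registers to shared memory
    //           (skipped when Use_smem is false, i.e. for output types wider than 2 bytes).
    //   Step 2: write the per-row LSE values from registers directly to global memory.
    //   Step 3: write O to global memory, either from shared memory (TMA or a predicated tiled
    //           copy) or straight from registers, depending on Use_TMA_O / Use_smem / PackGQA.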
    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& params,
          FrgTensorO const& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
          ) {

        auto [m_block, bidh, bidb, split_idx] = block_coord;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
        // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);

        Tensor tOrO_out = make_tensor_like<Element>(tOrO);
        flash::convert_type_out(tOrO, tOrO_out);
        if constexpr (FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4)) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }

        // Make sure all WGs have finished reading V
        // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
        // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
        // cp.async if we need).
        cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

        // Step 1: Write O from rmem -> smem
        if constexpr (Use_smem) {
            auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
            auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
            Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);        // ((Atom,AtomNum), MMA_M, MMA_N)
            Tensor taccOsO = smem_thr_copy_O.partition_D(sO);           // ((Atom,AtomNum),PIPE_M,PIPE_N)
            // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
            cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
            if constexpr (!Varlen && !PackGQA) {
                cutlass::arch::fence_view_async_shared();  // ensure smem writes are visible to TMA
                cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            } else {
                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            }
        } else {
            #pragma unroll
            for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                shared_storage.pipelines.barrier_O.arrive(cta_id);
            }
        }

        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
        bool is_varlen = Varlen && params.cu_seqlens;
        int offset_o = seqlen_info.offset;
        int seqlen_o = seqlen_info.seqlen;

        // Step 2: Write LSE from rmem -> gmem
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        // (MMA,MMA_M,MMA_K)
        Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K); we only take the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M

        using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>;

        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)),
                                  params.shape_LSE_packed,
                                  params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx);
        // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
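        // Only the threads whose accumulator fragment covers column 0 write LSE (the
        // get<1>(taccOcO_row(_0{})) == 0 check below), so each row's LSE is written exactly once.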
        if constexpr (!PackGQA) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
                if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
            }
        } else {
            PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
        }

        // Step 3: Write O from smem -> gmem
        if constexpr (Use_TMA_O) {
            Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
            Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
            auto block_tma_O = params.tma_store_O.get_slice(_0{});
            Tensor tOgO = block_tma_O.partition_D(gO);  // (TMA, TMA_M, TMA_K)
            Tensor tOsO = block_tma_O.partition_S(sO);  // (TMA, TMA_M, TMA_K)
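            // Only the last epilogue warp waits on EpilogueBarrier here (its arrival count includes
            // one extra warp beyond the epilogue threads, matching the arrive() in Step 1); a single
            // elected thread then issues the TMA store, waits for it to complete, and arrives on
            // barrier_O for every CTA in the cluster so the next TMA load into smem can proceed.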
            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
                if (cute::elect_one_sync()) {
                    cute::copy(params.tma_store_O, tOsO, tOgO);
                    tma_store_arrive();
                    tma_store_wait<0>();
                    #pragma unroll
                    for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                        shared_storage.pipelines.barrier_O.arrive(cta_id);
                    }
                }
            }
        } else {
            // Don't use TMA since we don't want to overwrite the output of another sequence
            Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)),
                                    params.shape_O_packed,
                                    params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
            Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
            // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); }
            if constexpr (Use_smem) {
                GmemTiledCopyO gmem_tiled_copy_O;
                auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
                Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
                // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
                Tensor tOrO = make_fragment_like(tOsO);
                cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
                cutlass::arch::fence_view_async_shared();  // ensure smem reads are done before next TMA to smem_v
                #pragma unroll
                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                    shared_storage.pipelines.barrier_O.arrive(cta_id);
                }
                if constexpr (!PackGQA) {
                    // (BLK_M,BLK_K) -> (blk_m,blk_k)
                    Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
                    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
                    #pragma unroll
                    for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
                    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
                    // Clear_OOB_K must be false since we don't want to write zeros to gmem
                    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
                    );
                } else {
                    // If PackGQA, we split the work of computing O_ptr among threads in the same row
                    PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
                }
            } else {
                // We already arrived on barrier_O earlier
                if constexpr (!PackGQA) {
                    static constexpr int kGmemElemsPerStoreDirect = 2;
                    cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<kGmemElemsPerStoreDirect * sizeof(Element) * 8>, Element> gmem_copy_direct;
                    // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
                    Tensor tOrO_rowcol = make_tensor(tOrO_out.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
                    Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
                    Tensor tOgO = thread_mma.partition_C(gO);
                    Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
                    Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
                    // taccOcO has shape ((2, 2, V), MMA_M, MMA_K); we only take the col indices.
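                    // The accumulator's row/col view is tiled into pieces of kGmemElemsPerStoreDirect (= 2)
                    // consecutive columns; each piece is stored with one small vectorized copy, guarded by
                    // the row check (within seqlen_o) and the column check (within the head dimension) below.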
                    Tensor taccOcO_col = taccOcO(make_coord(_, _0{}, _), _0{}, _);
                    #pragma unroll
                    for (int m = 0; m < size(taccOcO_row); ++m) {
                        if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
                            #pragma unroll
                            for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
                                if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
                                    cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
                                }
                            }
                        }
                    }
                } else {
                    PackGQAt::store_O_direct(mO, tOrO_out, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
                }
            }
        }
    }

    CUTLASS_DEVICE void
    store_tail() {
        // Don't need to do tma_store_wait<0>() here since we already did in @store
    }

    // Write 0 to output and -inf to LSE
    template <bool Clear_O=true>
    CUTLASS_DEVICE void
    store_zero(Params const& params,
               int thread_idx,
               cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
               ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        auto [m_block, bidh, bidb, split_idx] = block_coord;
        flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
        bool const is_varlen = Varlen && params.cu_seqlens;
        int offset_o = seqlen_info.offset;
        int seqlen_o = seqlen_info.seqlen;
        int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)),
                                params.shape_O_packed,
                                params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)),
                                  params.shape_LSE_packed,
                                  params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx);
        Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));

        static_assert(kBlockM <= NumEpilogueThreads);
        if (thread_idx < kBlockM) {
            const int row = m_block * kBlockM + thread_idx;
            if constexpr (!PackGQA) {
                if (row < seqlen_o) { mLSE(row) = -INFINITY; }
            } else {
                if (row < seqlen_o * qhead_per_khead) {
                    int m_idx, h_idx;
                    m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
                    // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
                    mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
                }
            }
        }

        if constexpr (!Clear_O) { return; }

        GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})));
        if constexpr (!PackGQA) {
            Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
            Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
            Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
            Tensor tOrO = make_fragment_like(tOgO);
            cute::clear(tOrO);
            // Clear_OOB_K must be false since we don't want to write zeros to gmem
            flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
                gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
            );
        } else {
            // If PackGQA, we split the work of computing O_ptr among threads in the same row
            using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>;
            Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
            cute::clear(tOrO);
            PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
        }
    }

};

} // namespace flash
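// Illustrative usage sketch (not part of the library; the concrete template arguments and kernel-side
// wiring shown here are assumptions for exposition only). A forward kernel would typically instantiate
// the epilogue with its tile shape and thread count, build Params on the host, and call store() /
// store_zero() per output tile:
//
//   using Epilogue = flash::CollectiveEpilogueFwd<Shape<_128, _128, _64>, Shape<_1, _1, _1>,
//                                                 cutlass::half_t, /*NumEpilogueThreads=*/256,
//                                                 /*Varlen=*/false, /*PackGQA=*/false>;
//   Epilogue::Params params = Epilogue::to_underlying_arguments({ptr_O, shape_O, stride_O,
//                                                                ptr_LSE, stride_LSE, nheads_kv,
//                                                                cu_seqlens, seqused});
//   // In the kernel, after the softmax/PV accumulation for a tile:
//   //   epilogue.store(params, tOrO, lse, shared_storage, tiled_mma, thread_idx, block_coord);
//   // or, for a tile with no valid work:
//   //   epilogue.store_zero(params, thread_idx, block_coord);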