/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/barrier.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/builders/sm90_common.inl"

#include "named_barrier.hpp"
#include "seqlen.h"
#include "mask.h"
#include "softmax.h"
#include "utils.h"
#include "copy_sm90_bulk_reduce.hpp"

namespace flash {

template <bool A, class Mma, class Tensor0>
CUTLASS_DEVICE
auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {
    if constexpr (A) {
        return mma.partition_fragment_A(tensor0);
    } else {
        return mma.partition_fragment_B(tensor0);
    }
}

using namespace cute;

template <int Stages, int Stages_dO, int Stages_dS,
          class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
          bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool Deterministic,
          bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
          int NumMmaWarpGroups, int AtomLayoutMSdP, int AtomLayoutNdKV, int AtomLayoutMdQ>
struct CollectiveMainloopBwd {

    static constexpr int kStages = Stages;
    static constexpr int kStages_dO = Stages_dO;
    static constexpr int kStages_dS = Stages_dS;
    static_assert(kStages >= kStages_dO);
    static_assert(Stages_dS == 1 || Stages_dS == kStages);
    using ClusterShape = ClusterShape_;
    using TileShape_MNK = TileShape_MNK_;
    using Element = Element_;
    using ElementAccum = ElementAccum_;
    using ArchTag = ArchTag_;
    static constexpr bool Is_causal = Is_causal_;
    static constexpr bool Is_local = Is_local_;
    static constexpr bool Has_softcap = Has_softcap_;
    static constexpr bool Varlen = Varlen_;
    using SeqlenInfo_t = flash::SeqlenInfoQK<Varlen, CUTE_STATIC_V(get<0>(TileShape_MNK{}))>;
    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static constexpr bool Q_dO_same_stages = kStages == kStages_dO;

    static constexpr int kBlockM = get<0>(TileShape_MNK{});
    static constexpr int kBlockN = get<1>(TileShape_MNK{});
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});

    static_assert(ArchTag::kMinComputeCapability >= 90);
    static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1);

    static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;

    static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0);
    static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0);
    static_assert(NumMmaWarpGroups % AtomLayoutMdQ == 0);

    static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarpGroups && SdP_swapAB && !dKV_swapAB;
    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarpGroups && AtomLayoutMdQ == NumMmaWarpGroups && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS

    static constexpr GMMA::Major PdS_Major = GMMA::Major::K;
    // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN;
    static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ?
GMMA::Major::MN : GMMA::Major::K; using TileShapeAtomSdP = std::conditional_t< !SdP_swapAB, Shape, Int, Int>, Shape, Int, Int> >; using AtomLayoutSdP = std::conditional_t< !SdP_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; using TiledMmaSdP = decltype(cute::make_tiled_mma( cute::GMMA::ss_op_selector(), AtomLayoutSdP{})); using TiledMmadPRS = decltype(cute::make_tiled_mma( cute::GMMA::rs_op_selector(), AtomLayoutSdP{})); using TileShapeAtomdKV = std::conditional_t< !dKV_swapAB, Shape, Int, Int>, Shape, Int, Int> >; using AtomLayoutdKV = std::conditional_t< !dKV_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; using TiledMmadKV = decltype(cute::make_tiled_mma( std::conditional_t< Mma_dKV_is_RS, decltype(cute::GMMA::rs_op_selector()), decltype(cute::GMMA::ss_op_selector()) >{}, AtomLayoutdKV{})); using TileShapeAtomdQ = std::conditional_t< !dQ_swapAB, Shape, Int, Int>, Shape, Int, Int> >; using AtomLayoutdQ = std::conditional_t< !dQ_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; using TiledMmadQ = decltype(cute::make_tiled_mma( std::conditional_t< Mma_dQ_is_RS, decltype(cute::GMMA::rs_op_selector()), decltype(cute::GMMA::ss_op_selector()) >{}, AtomLayoutdQ{})); // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension // changes the layout. using SmemLayoutAtomQdO = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); // for dKV_Mma using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQdO{}, make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutdO = decltype(tile_to_shape(SmemLayoutAtomQdO{}, make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}, Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80 // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, // it's still a valid smem address. using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; using SmemLayoutLSEMma = std::conditional_t< SdP_swapAB, cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>, cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>> >; // Note this is the transpose in terms of the view, not in terms of memory. 
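    // Rough illustration of the composition trick used below (simplified: no swizzle, no stage mode).
    // Composing the (M, K) layout with a (K, M)-shaped layout of stride (M, 1) yields a view whose
    // (k, m) element addresses the same smem element as Q(m, k), i.e. a transposed view with no data movement:
    //     auto sQ_layout  = make_layout(make_shape(Int<kBlockM>{}, Int<kHeadDim>{}));   // (M, K), column-major
    //     auto sQt_layout = composition(sQ_layout,
    //         make_layout(make_shape(Int<kHeadDim>{}, Int<kBlockM>{}),
    //                     make_stride(Int<kBlockM>{}, _1{})));                          // (K, M) view
    // The layouts below apply the same idea per pipeline stage, on top of the swizzled smem atoms.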
using SmemLayoutQt = decltype(cute::composition(SmemLayoutQ{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), make_stride(Int{}, _1{}, Int{})))); using SmemLayoutdOt = decltype(cute::composition(SmemLayoutdO{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), make_stride(Int{}, _1{}, Int{})))); using SmemLayoutKt = decltype(cute::composition(SmemLayoutK{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), make_stride(Int{}, _1{})))); using SmemLayoutPdSt = decltype(cute::composition(SmemLayoutPdS{}, make_layout(make_shape(Int{}, Int{}, Int{}), make_stride(Int{}, _1{}, Int{})))); // Thread layout, 256 or 384 threads per row // We split into NumMmaWarpGroups so that we can do Bulk reduce add for each WG separately. using R2SLayoutAtomdQaccum = Layout, Int>>; using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per store using SmemLayoutdQaccum = Layout, Int>>; static constexpr int kNumPdSStore = kBlockM * kBlockN / NumMmaThreads; // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt. // If PdS_major is MN, then we need to "transpose" the write. using SmemCopyAtomPdS = Copy_Atom< std::conditional_t<(!SdP_swapAB) ^ (PdS_Major == GMMA::Major::MN), std::conditional_t, std::conditional_t >, Element >; using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{}))); using GmemTiledCopyKV = cute::SM90_TMA_LOAD; using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) using StrideQKV = cute::Stride; using ShapeLSE = cute::Shape; // (seqlen, head, batch) using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; using TMA_QdO = decltype(make_tma_copy_A_sm90( GmemTiledCopyQdO{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), take<0, 2>(SmemLayoutQ{}), TileShape_MNK{}, ClusterShape{})); // mcast along N mode for this M load, if any using TMA_K = decltype(make_tma_copy_B_sm90( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), SmemLayoutK{}, TileShape_MNK{}, ClusterShape{})); // no mcast for KV using TMA_V = decltype(make_tma_copy_B_sm90( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), SmemLayoutV{}, TileShape_MNK{}, ClusterShape{})); // no mcast for KV using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename MainloopPipeline::PipelineState; using MainloopPipeline_dO = typename cutlass::PipelineTmaAsync; using PipelineState_dO = typename MainloopPipeline_dO::PipelineState; // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(SmemLayoutK{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(SmemLayoutV{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesLSE = static_cast(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v / 8); // These are tuned for speed. They don't affect correctness. // We have separate iterations with causal masking. 
Not necessary for hdim 128 but for hdim 64 // this helps quite a bit to not have to do causal masking for most of the iterations. // For hdim 192, separating masking iterations results in register spills. static constexpr bool SeparateMaskingIterations = kHeadDim <= 64; // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep // statistic for 2 rows. static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64; static constexpr bool ShuffledPSum = SdP_swapAB && kHeadDim <= 64; // If we have extra registers, we can keep V in registers to reduce smem traffic. static constexpr bool Mma_dP_is_RS = SdP_swapAB && kHeadDim == 96; static constexpr bool dQacc_use_TMA = kHeadDim < 256; // For hdim256, we want to slice the dQ MMA (64 x 256 on 2 WGs) into two (64 x 128 on 2 WGs) so that we can // do atomic add on one half before doing the other half of the MMA, to reduce register pressure. static constexpr bool Slice_dQKV_Mma = kHeadDim == 256 && !dQacc_use_TMA && dQ_swapAB && AtomLayoutMdQ == 1 && NumMmaWarpGroups == 2; static_assert(!(Deterministic && Slice_dQKV_Mma), "Deterministic mode not supported with Slice_dQKV_Mma"); static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{}); static constexpr size_t SmemAlignmentdS = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{}); // Without this SmemAlignment, with hdim 256 we get "misaligned address" error in TMA static constexpr size_t SmemAlignmentQKVdO = kHeadDim % 256 == 0 ? 256 : 128; static constexpr size_t SmemAlignmentV = !Mma_dP_is_RS ? SmemAlignmentQKVdO : cutlass::detail::alignment_for_swizzle(SmemLayoutV{}); static_assert(SmemAlignmentP >= 128 && SmemAlignmentdS >= 128, "Require at least 128B alignment"); // TODO: do we have to worry that smem_dk and smem_dv in the epilogue don't line up w smem_k and smem_v due to alignment? 
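    // Worked example (hypothetical configuration, for orientation only): with kHeadDim = 64, kBlockM = 128
    // and SdP_swapAB, the tuning flags above resolve to SeparateMaskingIterations = true and
    // ShuffleLSE = ShuffledPSum = true (each thread would otherwise keep kBlockM / 4 = 32 LSE values and
    // 32 dPsum values; sharing them across 8 threads leaves ceil(32 / 8) = 4 per thread, cf. kStatsPerThread
    // in mma() below), while Mma_dP_is_RS = false, dQacc_use_TMA = true, and Slice_dQKV_Mma = false.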
using SmemdQacc_t = std::conditional_t, cute::array_aligned>>; using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; struct TensorStorage : cute::aligned_struct { cute::array_aligned, SmemAlignmentQKVdO> smem_k; cute::array_aligned, SmemAlignmentV> smem_v; SmemdQacc_t smem_dqacc; cute::array_aligned, SmemAlignmentQKVdO> smem_q; cute::array_aligned, SmemAlignmentQKVdO> smem_do; cute::array_aligned, 128> smem_lse; cute::array_aligned, 128> smem_dpsum; SmemP_t smem_p; cute::array_aligned, SmemAlignmentdS> smem_ds; }; // Host side kernel arguments struct Arguments { Element const* ptr_Q; ShapeQKV const shape_Q; StrideQKV const stride_Q; Element const* ptr_K; ShapeQKV const shape_K; StrideQKV const stride_K; Element const* ptr_V; StrideQKV const stride_V; Element const* ptr_dO; StrideQKV const stride_dO; ElementAccum* ptr_dQaccum; ShapedQaccum const shape_dQaccum; StridedQaccum const stride_dQaccum; float const* ptr_LSE_log2; ShapeLSE const shape_LSE; StrideLSE const stride_LSE_log2; float const* ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; int const window_size_left, window_size_right, sink_token_length; float const softcap_val; int const num_batch; int* const dq_semaphore; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; }; // Device side kernel params struct Params { ShapeQKV const shape_Q; ShapeQKV const shape_K; ShapedQaccum const shape_dQaccum; ElementAccum* ptr_dQaccum; StridedQaccum stride_dQaccum; cutlass::FastDivmod qhead_per_khead_divmod; TMA_QdO tma_load_Q, tma_load_dO; TMA_K tma_load_K; TMA_V tma_load_V; float const* ptr_LSE_log2; ShapeLSE const shape_LSE; StrideLSE const stride_LSE_log2; float const* ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; int const window_size_left, window_size_right, sink_token_length; float const softcap_val; int const num_batch; int* const dq_semaphore; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; }; static Params to_underlying_arguments(Arguments const& args) { Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); TMA_QdO tma_load_Q = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mQ, SmemLayoutQ{}(_, _, _0{}), TileShape_MNK{}, ClusterShape{}); // mcast along N mode for this M load, if any Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO); TMA_QdO tma_load_dO = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mdO, SmemLayoutdO{}(_, _, _0{}), TileShape_MNK{}, ClusterShape{}); // mcast along N mode for this M load, if any Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); TMA_K tma_load_K = make_tma_copy_B_sm90( GmemTiledCopyKV{}, mK, SmemLayoutK{}, TileShape_MNK{}, ClusterShape{}); // no mcast for KV Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); TMA_V tma_load_V = make_tma_copy_B_sm90( GmemTiledCopyKV{}, mV, SmemLayoutV{}, TileShape_MNK{}, ClusterShape{}); // no mcast for KV if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. 
// To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). // In the backward, we need to multiply by // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale // (the original softmax_scale) at the end. return {args.shape_Q, args.shape_K, args.shape_dQaccum, args.ptr_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), args.window_size_left, args.window_size_right, args.sink_token_length, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& params) { cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); } CUTLASS_DEVICE cute::tuple get_m_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, int n_block, int bidb) { static constexpr int kBlockM = get<0>(TileShape_MNK{}); int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; int m_block_max = cute::ceil_div(seqlen_q, kBlockM); if constexpr (Is_local) { static constexpr int kBlockN = get<1>(TileShape_MNK{}); if (n_block >= cute::ceil_div(params.sink_token_length, kBlockN)) { m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM)); } } int m_block_min = 0; if constexpr (Is_causal || Is_local) { m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM); } return {m_block_min, m_block_max}; } template CUTLASS_DEVICE void load(Params const& params, MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do, SharedStorage &shared_storage, SchedulerPrefetch const& scheduler_prefetch, cute::tuple block_coord ) { auto [n_block, bidh, bidb] = block_coord; SeqlenInfo_t seqlen_info{ bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access. 
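        // For example (hypothetical varlen case, assuming another sequence in the batch is longer so this
        // n_block is still scheduled): seqlen_q = seqlen_k = 100 with kBlockM = 64, kBlockN = 128 gives
        // m_block_max = 2, and the tile n_block = 1 (keys 128..255, entirely past the end of this sequence)
        // gives m_block_min = (1 * 128 + 100 - 100 - 0) / 64 = 2 under causal masking with
        // window_size_right = 0, so the guard below exits without issuing any loads.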
if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { scheduler_prefetch(); return; } } Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); // Prepare the TMA loads uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? 
bidb : 0); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{})); Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{})); Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{})); Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{})); // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout{}, // group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE) // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout{}, // group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE) auto block_tma_Q = params.tma_load_Q.get_slice(cluster_local_block_id.y); auto block_tma_dO = params.tma_load_dO.get_slice(cluster_local_block_id.y); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO)); Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO)); auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{}, group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA) auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA) auto bulk_copy = Copy_Traits{}; uint16_t mcast_mask_qdo = 0; if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{})); } } int m_block = m_block_min; int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { pipeline_q.producer_acquire(smem_pipe_write); copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index())); copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), gLSE(_, m_block), sLSE(_, smem_pipe_write.index())); } // // Wait for the MMA warpgroups to say that smem_k and smem_v are ready // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); if (lane_predicate) { // Copy K tile and V tile from GMEM to SMEM. 
shared_storage.pipelines.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV); copy(params.tma_load_K.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK); copy(params.tma_load_V.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV); #pragma unroll (kHeadDim < 256 ? 2 : 1) for (; m_block < m_block_max - 1; ++m_block) { // If Q and dO have the same number of stages, we can use the same pipeline state variable // to reduce registers PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); pipeline_do.producer_acquire(smem_pipe_write_do_cur); copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } ++smem_pipe_write; pipeline_q.producer_acquire(smem_pipe_write); copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index())); copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), gLSE(_, m_block + 1), sLSE(_, smem_pipe_write.index())); } } scheduler_prefetch(); if (lane_predicate) { PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); pipeline_do.producer_acquire(smem_pipe_write_do_cur); copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } ++smem_pipe_write; } if constexpr (Q_dO_same_stages) { smem_pipe_write_do = smem_pipe_write; } } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, PipelineState& smem_pipe_write) { static_assert(Q_dO_same_stages, "Q and dO must have the same number of stages"); // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write PipelineState smem_pipe_write_do = smem_pipe_write; // Issue the epilogue waits if (cute::elect_one_sync()) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used * then would just be acquired since the phase was still inverted from make_producer_start_state */ pipeline_q.producer_tail(smem_pipe_write); pipeline_do.producer_tail(smem_pipe_write_do); } } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do) { // Issue the epilogue waits if (cute::elect_one_sync()) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used * then would just be acquired since the phase was still inverted from 
make_producer_start_state */ pipeline_q.producer_tail(smem_pipe_write); pipeline_do.producer_tail(smem_pipe_write_do); } } template CUTLASS_DEVICE void store_dq(Params const& params, SharedStorage &shared_storage, cute::tuple block_coord ) { if constexpr (!dQacc_use_TMA) { return; } auto [n_block, bidh, bidb] = block_coord; SeqlenInfo_t seqlen_info{ bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return; } } Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); static constexpr int dQ_TMA_num_bytes = CUTE_STATIC_V(size<0>(sdQ)) * sizeof(ElementAccum); bool const is_varlen = Varlen && params.cu_seqlens_q; Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) int const num_batch = params.num_batch; int const num_head = get<2>(params.shape_Q); int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh; using Barrier = cutlass::GenericBarrier; bool const lane_predicate = cute::elect_one_sync(); int m_block = m_block_min; #pragma unroll 2 for (; m_block < m_block_max; ++m_block) { if constexpr (Deterministic) { Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); } #pragma unroll for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) { cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/); // sdQ full, to be written to gmem if (lane_predicate) { SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); tma_store_arrive(); } } // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int. 
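            // (Presumably the compile-time index is needed because it feeds the Count template parameter of
            // cute::tma_store_wait<Count>() inside the lambda; a plain runtime loop index would not compile there.)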
for_each(make_int_sequence{}, [&] (auto warpgroup_idx) { if (lane_predicate) { tma_store_wait(); } cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/); // sdQ empty, ready to be written to }); if constexpr (Deterministic) { Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); } } if constexpr (Is_local && Deterministic) { constexpr int kBlockM = get<0>(TileShape_MNK{}); int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM); #pragma unroll 2 for (; m_block < m_block_global_max; ++m_block) { Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); } } } CUTLASS_DEVICE void mma_init() { // Tell producer (warp 0) that smem_k and smem_v are ready // We're not currently using this bc we're not using persistent scheduler // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); if constexpr (dQacc_use_TMA) { if (warp_idx_in_warpgroup == 0) { cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); // sdQ empty, ready to be written to } } } template CUTLASS_DEVICE bool mma(Params const& params, MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, PipelineState& smem_pipe_read, PipelineState_dO& smem_pipe_read_do, FrgTensordKV& tdKrdK, FrgTensordKV& tdVrdV, int thread_idx, int &work_idx, cute::tuple block_coord, SharedStorage& shared_storage ) { static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); int n_block = get<0>(block_coord); int bidb = get<2>(block_coord); SeqlenInfo_t seqlen_info{ bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb); // It's possible to have m_block_max <= m_block_min. 
Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return false; } } Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); Tensor sP_pi = cute::as_position_independent_swizzle_tensor(sP); Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); Tensor sPt_pi = cute::as_position_independent_swizzle_tensor(sPt); Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); Tensor sdS_pi = cute::as_position_independent_swizzle_tensor(sdS); Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); Tensor sdSt_pi = cute::as_position_independent_swizzle_tensor(sdSt); Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); Layout warp_group_thread_layout_dq = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); TiledMmaSdP tiled_mma_SdP; using TiledMmadP = std::conditional_t; TiledMmadP tiled_mma_dP; TiledMmadKV tiled_mma_dKV; TiledMmadQ tiled_mma_dQ; auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx)); auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx)); auto thread0_mma_SdP = tiled_mma_SdP.get_thread_slice(_0{}); auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx); R2STiledCopydQaccum r2s_tiled_copy_dQaccum; auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ); // if 
(thread_idx == 0) { print(sdQ); printf("\n"); print(tdQsdQaccum); printf("\n"); } // Allocate "fragments/descriptors" // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, // because some partition_fragment_A/B don't compile. // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function Tensor tSrQ = mma_partition_fragment_AB(wg_mma_SdP, sQ); Tensor tSrK = mma_partition_fragment_AB(wg_mma_SdP, sK); Tensor tdPrdO = mma_partition_fragment_AB(wg_mma_SdP, sdO); Tensor tdPrV = mma_partition_fragment_AB(wg_mma_dP, sV); Tensor tdVrdO = mma_partition_fragment_AB(wg_mma_dKV, sdOt); Tensor tdKrQ = mma_partition_fragment_AB(wg_mma_dKV, sQt); Tensor tdQrdS = mma_partition_fragment_AB(wg_mma_dQ, sdS); Tensor tdQrK = mma_partition_fragment_AB(wg_mma_dQ, sKt); Tensor tPsP = smem_thr_copy_PdS.partition_D(cute::conditional_return(sP_pi, sPt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor tdSsdS = smem_thr_copy_PdS.partition_D(cute::conditional_return(sdS_pi, sdSt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_PdS); print(sP_pi); printf("\n"); print(sPt_pi); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); } // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices // or row indices, depending on whether SdP_swapAB. Tensor tLSEsLSE = cute::conditional_return( group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE) group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE) Tensor tLSEsdPsum = cute::conditional_return( group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _))); // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } // If we want to split the stats among the 8 threads that share the same rows. static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tLSEsLSE))::value, 8); auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); }; int bidh = get<1>(block_coord); int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; // For the case where we do atomicAdd directly to gdQaccum instead of using TMA bool const is_varlen = Varlen && params.cu_seqlens_q; Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? 
bidb : 0); Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) // We can reuse r2s_thr_copy_dQaccum for this partitioning Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, params.qhead_per_khead_divmod ); int m_block = m_block_min; clear(tdKrdK); clear(tdVrdV); // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero; cutlass::ConsumerToken barrier_token = static_cast(shared_storage.pipelines.barrier_KV.try_wait(work_idx % 2)); if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_KV.wait(work_idx % 2); } if constexpr (Mma_dP_is_RS) { using SmemCopyAtomV = Copy_Atom; auto smem_tiled_copy_V = make_tiled_copy_A(SmemCopyAtomV{}, tiled_mma_dP); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); Tensor tdPrV_copy_view = smem_thr_copy_V.retile_D(tdPrV); Tensor tdPsV_copy_view = smem_thr_copy_V.partition_S(cute::as_position_independent_swizzle_tensor(sV)); cute::copy(smem_tiled_copy_V, tdPsV_copy_view, tdPrV_copy_view); } auto bwd_step = [&](int m_block, auto mask_fn) { Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); consumer_wait(pipeline_q, smem_pipe_read); flash::gemm(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS); Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor(Int{})); if constexpr (!ShuffleLSE) { cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE); } else { #pragma unroll for (int i = 0; i < kStatsPerThread; ++i) { // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values tLSErLSE(i) = tLSEsLSE((thread_idx % 32) / 4 + i * 8, smem_pipe_read.index()); } } Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return(smem_pipe_read, smem_pipe_read_do); consumer_wait(pipeline_do, smem_pipe_read_do_cur); flash::gemm(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP); warpgroup_wait<1>(); if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); } // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }(); mask_fn(tSrS, m_block); #pragma unroll for (int mi = 0; mi < size<0>(scores); ++mi) { float const lse_scaled = [&] { if constexpr (!ShuffleLSE) return tLSErLSE(mi); else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); }(); #pragma unroll for (int ni = 0; ni < size<1>(scores); ++ni) { scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled); } } Tensor tLSErdPsum = cute::conditional_return(make_fragment_like(tLSEsdPsum(_, _0{})), make_tensor(Int{})); if constexpr (!ShuffledPSum) { cute::copy(tLSEsdPsum(_, 
smem_pipe_read_do_cur.index()), tLSErdPsum); } else { #pragma unroll for (int i = 0; i < kStatsPerThread; ++i) { tLSErdPsum(i) = tLSEsdPsum((thread_idx % 32) / 4 + i * 8, smem_pipe_read_do_cur.index()); } } warpgroup_wait<0>(); // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { float const dP_sum_cur = [&] { if constexpr (!ShuffledPSum) return tLSErdPsum(mi); else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4)); }(); #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); } } } // Convert scores from fp32 to fp16/bf16 Tensor rP = make_tensor_like(tSrS); flash::convert_type_out(tSrS, rP); if constexpr (!Mma_dKV_is_RS) { // Need to sync to make sure P has already been used in the previous iteration before writing new values if constexpr (kStages_dS == 1) { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); } Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tPaP, tPsP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); } Tensor rdS = make_tensor_like(tdPrdP); flash::convert_type_out(tdPrdP, rdS); // If there's double buffering on dS, we don't need to sync here. // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. // But because both WGs have to sync at the end of the loop and double buffering, // this race condition is not possible. // This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and // (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS. 
if constexpr (!Mma_dKV_is_RS || (kStages_dS == 1 && Mma_dKV_is_RS)) { cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); } // For hdim 64, It's faster to write to smem_dS first before the dV gemm Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); if constexpr (!Slice_dQKV_Mma) { // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure if constexpr (Mma_dKV_is_RS) { Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); flash::gemm(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); } else { Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); } // SMEM fence to make sure sdS is written before it's read by WGMMA cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dQ if constexpr (Mma_dKV_is_RS) { Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); flash::gemm(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); } else { Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); } if constexpr (dQacc_use_TMA) { int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1; cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem } else { // We can reuse r2s_thr_copy_dQaccum for this partitioning Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); #pragma unroll for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } } } else { // Slice_dQKV_Mma static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS)); Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); 
Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); #pragma unroll for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); #pragma unroll for (int i = size(tdQrdQ_atomic) / 2; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); } warpgroup_wait<0>(); pipeline_q.consumer_release(smem_pipe_read); // release Q ++smem_pipe_read; if constexpr (!Q_dO_same_stages) { ++smem_pipe_read_do; } }; // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 // this helps quite a bit to not have to do causal masking for most of the iterations. if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) { auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; static constexpr int kBlockM = get<0>(TileShape_MNK{}); int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) { bwd_step(m_block, mask_fn); } } static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations ? m_block_max : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM); auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < m_block_max_before_local_mask; ++m_block) { bwd_step(m_block, mask_fn); } if constexpr (Is_local && SeparateMaskingIterations) { auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < m_block_max; ++m_block) { bwd_step(m_block, mask_fn); } } // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); } #pragma unroll for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; } if constexpr (Q_dO_same_stages) { smem_pipe_read_do = smem_pipe_read; } ++work_idx; return true; } }; } // namespace flash
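// -----------------------------------------------------------------------------------------------------
// Illustrative notes (not used by the kernel above): the scalar math the mainloop implements, written out
// under the same conventions as the code. Symbol names mirror the Params fields; this is a sketch for
// orientation, not part of the interface.
//
//   P  = exp(S * softmax_scale - LSE) = exp2f(S * softmax_scale_log2 - LSE_log2)
//        (ptr_LSE_log2 stores LSE pre-multiplied by log2(e); softmax_scale_log2 = softmax_scale * log2(e))
//   dS = P * (dP - dPsum), with dPsum = rowsum(dO * O) computed ahead of the mainloop (ptr_dPsum).
//
// With tanh softcapping, S is replaced by tanh(S * softmax_scale / softcap_val) * softcap_val; as noted in
// to_underlying_arguments(), params.softcap_val then stores softmax_scale / softcap_val,
// params.softmax_scale_log2 stores softcap_val * log2(e), and the backward multiplies dS by (1 - tanh^2).
// In all cases dK picks up the original softmax_scale at the end of mma()
// (the `tdKrdK(i) *= params.softmax_scale` loop).
// -----------------------------------------------------------------------------------------------------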