/****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include #include #include #include #include #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "named_barrier.hpp" #include "softmax.h" #include "utils.h" namespace flash { using namespace cute; template struct CollectiveMainloopBwd { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; static constexpr bool Is_causal = Is_causal_; static constexpr bool Is_local = Is_local_; static constexpr bool Varlen = Varlen_; static constexpr bool SdP_swapAB = true; static constexpr bool dKV_swapAB = dKV_swapAB_; static constexpr bool dQ_swapAB = dQ_swapAB_; static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); static constexpr int NumdQWarpGroups = 2; static constexpr int kNThreadsdQ = NumdQWarpGroups * cutlass::NumThreadsPerWarpGroup; static_assert(ArchTag::kMinComputeCapability >= 90); static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1); static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS using TileShapeAtomSdP = std::conditional_t< !SdP_swapAB, Shape, Int, Int>, Shape, Int, Int> >; using AtomLayoutSdP = std::conditional_t< !SdP_swapAB, Layout, Int<2 / AtomLayoutMSdP>, _1>>, Layout, Int, _1>> >; using TiledMmaSdP = decltype(cute::make_tiled_mma( cute::GMMA::ss_op_selector(), AtomLayoutSdP{})); using TileShapeAtomdKV = std::conditional_t< !dKV_swapAB, Shape, Int, Int>, Shape, Int, Int> >; using AtomLayoutdKV = std::conditional_t< !dKV_swapAB, Layout, Int<2 / AtomLayoutNdKV>, _1>>, Layout, Int, _1>> >; using TiledMmadKV = decltype(cute::make_tiled_mma( std::conditional_t< !SdP_swapAB, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, AtomLayoutdKV{})); using TileShapeAtomdQ = std::conditional_t< !dQ_swapAB, Shape, Int, Int>, Shape, Int, Int> >; using AtomLayoutdQ = std::conditional_t< !dQ_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN; static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? 
GMMA::Major::MN : GMMA::Major::K; using TiledMmadQ = decltype(cute::make_tiled_mma( std::conditional_t< !dQ_swapAB, std::conditional_t< Mma_dQ_is_RS, decltype(cute::GMMA::rs_op_selector()), decltype(cute::GMMA::ss_op_selector()) >, decltype(cute::GMMA::ss_op_selector()) >{}, AtomLayoutdQ{})); using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutdO = SmemLayoutQ; using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>()); using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, make_shape(Int{}, Int{}, Int{}))); // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80 using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; using SmemLayoutLSEMma = cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>; // Note this is the transpose in terms of the view, not in terms of memory. using SmemLayoutQt = decltype(cute::composition(SmemLayoutQ{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), make_stride(Int{}, _1{}, Int{})))); using SmemLayoutdOt = decltype(cute::composition(SmemLayoutdO{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), make_stride(Int{}, _1{}, Int{})))); using SmemLayoutKt = decltype(cute::composition(SmemLayoutK{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), make_stride(Int{}, _1{})))); using SmemLayoutPt = decltype(cute::composition(SmemLayoutP{}, make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), make_stride(Int{}, _1{})))); using SmemLayoutdSt = decltype(cute::composition(SmemLayoutdS{}, make_layout(make_shape(Int{}, Int{}, Int{}), make_stride(Int{}, _1{}, Int{})))); // Thread layout, 256 threads per row using R2SLayoutAtomdQaccum = Layout>, Stride<_1>>; using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, R2SLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per store using SmemLayoutdQaccum = Layout>, Stride<_1>>; // We want dQaccum smem to have last dimension 32, so that we only need to do 1 TMA instruction. // The layout Layout_K_SW128_Atom has 32 elements per row. // // TMA limit is that each dimension in smem must be <= 256. // static constexpr int ElemsPerRowTMA = (kBlockM * kHeadDim) / 32 <= 256 ? 32 : 64; static constexpr int ElemsPerRowTMA = 32; // If we change this, we'll also need to change the dQ shape in host. 
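    // Worked example of the reshape above (illustrative numbers only, assuming kBlockM = 64 and
    // kHeadDim = 128): the (64, 128) fp32 dQaccum tile is viewed as (64 * 128 / 32, 32) = (256, 32),
    // so a single TMA_REDUCE_ADD covers it while both smem box dimensions stay within the 256 limit
    // mentioned above. The assert below merely restates that the reshape preserves the element
    // count; it adds no constraint beyond the divisibility check that follows.
    static_assert(ElemsPerRowTMA * ((kBlockM * kHeadDim) / ElemsPerRowTMA) == kBlockM * kHeadDim);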
static_assert((kBlockM * kHeadDim) % ElemsPerRowTMA == 0); using TileShape_dQaccum = cute::Shape, Int>; // using TileShape_dQaccum = cute::Shape, Int>; using SmemLayoutdQaccumTMA = decltype(tile_to_shape(GMMA::Layout_K_SW128_Atom{}, TileShape_dQaccum{})); using SmemLayoutdQaccumTMANoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutdQaccumTMA{})); using SmemCopyAtomPdS = Copy_Atom< std::conditional_t, Element>; using SmemCopyAtomdKV = Copy_Atom< std::conditional_t, Element>; using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{}))); using GmemTiledCopyKV = cute::SM90_TMA_LOAD; using GmemTiledCopydQaccum = cute::SM90_TMA_REDUCE_ADD; using GmemTiledCopyLSE = cute::SM90_TMA_LOAD; using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) using StrideQKV = cute::Stride; using ShapeLSE = cute::Shape; // (seqlen, head, batch) using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) using TMA_QdO = decltype(make_tma_copy( GmemTiledCopyQdO{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), take<0, 2>(SmemLayoutQ{}), select<0, 2>(TileShape_MNK{}), size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any using TMA_K = decltype(make_tma_copy( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), SmemLayoutK{}, select<1, 2>(TileShape_MNK{}), _1{})); // no mcast for KV using TMA_V = decltype(make_tma_copy( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), SmemLayoutV{}, select<1, 2>(TileShape_MNK{}), _1{})); // no mcast for KV using TMA_add_dQ = decltype(make_tma_copy( GmemTiledCopydQaccum{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), SmemLayoutdQaccumTMA{}, TileShape_dQaccum{}, _1{})); // no mcast for dQ using TMA_LSE = decltype(make_tma_copy( GmemTiledCopyLSE{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeLSE{}, StrideLSE{}), select<0>(SmemLayoutLSE{}), select<0>(TileShape_MNK{}), _1{})); // no mcast for LSE static constexpr int NumMmaThreads = size(TiledMmaSdP{}); using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename MainloopPipeline::PipelineState; // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(SmemLayoutK{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(SmemLayoutV{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesLSE = static_cast(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v / 8); struct TensorStorage : cute::aligned_struct<1024> { cute::array_aligned> smem_k; cute::array_aligned> smem_v; // It's important that smem_dqacc is aligned to 1024 bytes for the TMA, so that the 1st row // has no swizzle. // If the address is only 128 bytes aligned, it's possible that the 1st row has swizzle // and when we read it back in the postprocess kernel, the swizzle will not match. 
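        // Why 1024 bytes and not just 128 (explanatory sketch, based on the SW128 swizzle atom used
        // by SmemLayoutdQaccumTMA and assuming a 4-byte ElementAccum): each dQaccum row is
        // ElemsPerRowTMA = 32 floats = 128 bytes, and the 128-byte swizzle pattern repeats every
        // 8 such rows, i.e. every 8 * 32 * 4 = 1024 bytes. Aligning the base pointer to that period
        // guarantees the first row sits at swizzle phase 0, matching what the postprocess kernel
        // expects when it reads dQaccum back.
        static_assert(8 * ElemsPerRowTMA * sizeof(float) == 1024);  // illustrative arithmetic only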
cute::array_aligned, 1024> smem_dqacc; cute::array_aligned> smem_q; cute::array_aligned> smem_do; cute::array_aligned> smem_ds; cute::array_aligned, 128> smem_lse; cute::array_aligned, 128> smem_dpsum; }; static constexpr int SharedStorageQdOSize = sizeof(decltype((TensorStorage{}).smem_q)) + sizeof(decltype((TensorStorage{}).smem_do)) + sizeof(decltype((TensorStorage{}).smem_ds)) + sizeof(decltype((TensorStorage{}).smem_dqacc)); // Host side kernel arguments struct Arguments { Element const* ptr_Q; ShapeQKV const shape_Q; StrideQKV const stride_Q; Element const* ptr_K; ShapeQKV const shape_K; StrideQKV const stride_K; Element const* ptr_V; StrideQKV const stride_V; Element const* ptr_dO; StrideQKV const stride_dO; ElementAccum* ptr_dQaccum; ShapeQKV const shape_dQaccum; StrideQKV const stride_dQaccum; float const* ptr_LSE_log2; ShapeLSE const shape_LSE; StrideLSE const stride_LSE_log2; float const* ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; int num_batch; int* dq_semaphore; int const* cu_seqlens_q = nullptr; int const* cu_seqlens_k = nullptr; int const* seqused_k = nullptr; int const* seqused_v = nullptr; int window_size_left; int window_size_right; }; // Device side kernel params struct Params { ShapeQKV const shape_Q; ShapeQKV const shape_K; ShapeQKV const shape_dQaccum; cutlass::FastDivmod qhead_per_khead_divmod; TMA_QdO tma_load_Q, tma_load_dO; TMA_K tma_load_K; TMA_V tma_load_V; TMA_add_dQ tma_add_dQ; TMA_LSE tma_load_LSE, tma_load_dPsum; float const* ptr_LSE_log2; ShapeLSE const shape_LSE; StrideLSE const stride_LSE_log2; float const* ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; float const softmax_scale_log2; int num_batch; int* dq_semaphore; int const* cu_seqlens_q = nullptr; int const* cu_seqlens_k = nullptr; int const* seqused_q = nullptr; int const* seqused_k = nullptr; int window_size_left; int window_size_right; }; static Params to_underlying_arguments(Arguments const& args) { Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); TMA_QdO tma_load_Q = make_tma_copy( GmemTiledCopyQdO{}, mQ, SmemLayoutQ{}(_, _, _0{}), select<0, 2>(TileShape_MNK{}), size<1>(ClusterShape{})); // mcast along N mode for this M load, if any Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO); TMA_QdO tma_load_dO = make_tma_copy( GmemTiledCopyQdO{}, mdO, SmemLayoutdO{}(_, _, _0{}), select<0, 2>(TileShape_MNK{}), size<1>(ClusterShape{})); // mcast along N mode for this M load, if any Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); TMA_K tma_load_K = make_tma_copy( GmemTiledCopyKV{}, mK, SmemLayoutK{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for KV Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); TMA_V tma_load_V = make_tma_copy( GmemTiledCopyKV{}, mV, SmemLayoutV{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for KV Tensor mdQaccum = make_tensor(make_gmem_ptr(args.ptr_dQaccum), args.shape_dQaccum, args.stride_dQaccum); TMA_add_dQ tma_add_dQ = make_tma_copy( GmemTiledCopydQaccum{}, mdQaccum, SmemLayoutdQaccumTMA{}, TileShape_dQaccum{}, _1{}); // no mcast for dQaccum Tensor mLSE = make_tensor(make_gmem_ptr(args.ptr_LSE_log2), args.shape_LSE, args.stride_LSE_log2); TMA_LSE tma_load_LSE = make_tma_copy( GmemTiledCopyLSE{}, mLSE, select<0>(SmemLayoutLSE{}), select<0>(TileShape_MNK{}), _1{}); // no mcast for LSE Tensor mdPsum = make_tensor(make_gmem_ptr(args.ptr_dPsum), args.shape_LSE, args.stride_dPsum); TMA_LSE 
        tma_load_dPsum = make_tma_copy(
            GmemTiledCopyLSE{},
            mdPsum,
            select<0>(SmemLayoutLSE{}),
            select<0>(TileShape_MNK{}),
            _1{}); // no mcast for dPsum
        if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }
        return {args.shape_Q, args.shape_K, args.shape_dQaccum,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
                tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, tma_add_dQ, tma_load_LSE, tma_load_dPsum,
                args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
                args.softmax_scale, float(args.softmax_scale * M_LOG2E),
                args.num_batch, args.dq_semaphore,
                args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_k, args.seqused_v,
                args.window_size_left, args.window_size_right};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_LSE.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_dPsum.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_add_dQ.get_tma_descriptor());
    }

    CUTLASS_DEVICE
    int get_seqlen_q(Params const& params, int bidb) {
        if constexpr (!Varlen) {
            return get<0>(params.shape_Q);
        } else {
            return params.cu_seqlens_q == nullptr
                ? get<0>(params.shape_Q)
                : (params.seqused_q
                    ? params.seqused_q[bidb]
                    : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]);
        }
    }

    CUTLASS_DEVICE
    int get_seqlen_k(Params const& params, int bidb) {
        if constexpr (!Varlen) {
            return get<0>(params.shape_K);
        } else {
            return params.cu_seqlens_k == nullptr
                ? get<0>(params.shape_K)
                : (params.seqused_k
                    ? params.seqused_k[bidb]
                    : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]);
        }
    }

    CUTLASS_DEVICE
    int get_m_block_min(Params const& params, int n_block, int bidb) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        if constexpr (Is_causal || Is_local) {
            int const seqlen_q = get_seqlen_q(params, bidb);
            int const seqlen_k = get_seqlen_k(params, bidb);
            return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM);
        } else {
            return 0;
        }
    }

    CUTLASS_DEVICE
    int get_m_block_max(Params const& params, int n_block, int bidb) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        int const seqlen_q = get_seqlen_q(params, bidb);
        int const seqlen_k = get_seqlen_k(params, bidb);
        int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
        if constexpr (Is_local) {
            return std::min(m_block_max,
                            cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM));
        } else {
            return m_block_max;
        }
    }

    template <typename SchedulerPrefetch, typename SharedStorage>
    CUTLASS_DEVICE void
    load(Params const& params,
         MainloopPipeline pipeline_q,
         MainloopPipeline pipeline_do,
         PipelineState& smem_pipe_write,
         SharedStorage &shared_storage,
         SchedulerPrefetch const& scheduler_prefetch,
         cute::tuple<int32_t, int32_t, int32_t> block_coord,
         int work_idx
         ) {
        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQ{});
        Tensor sdO = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdO{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_v.data()), SmemLayoutV{});
        Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_lse.data()), SmemLayoutLSE{});
        Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dpsum.data()), SmemLayoutLSE{});

        auto [n_block, bidh, bidb] = block_coord;
        int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        bool const is_varlen_q = Varlen && params.cu_seqlens_q != nullptr;
        bool const is_varlen_k = Varlen && params.cu_seqlens_k != nullptr;
        Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
        Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
        Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
        Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
        Tensor mLSE = params.tma_load_LSE.get_tma_tensor(params.shape_LSE)(_, bidh, !is_varlen_q ? bidb : 0);
        Tensor mdPsum = params.tma_load_dPsum.get_tma_tensor(params.shape_LSE)(_, bidh, !is_varlen_q ? bidb : 0);
        int const offset_q = !is_varlen_q ? 0 : params.cu_seqlens_q[bidb];
        int const offset_k = !is_varlen_k ? 0 : params.cu_seqlens_k[bidb];
        int const offset_padded = !is_varlen_q ?
0 : (params.cu_seqlens_q[bidb] + bidb * 128) / 128 * 128; Tensor gQ = local_tile(domain_offset(make_coord(offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gdO = local_tile(domain_offset(make_coord(offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) Tensor gLSE = local_tile(domain_offset(make_coord(offset_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) Tensor gdPsum = local_tile(domain_offset(make_coord(offset_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{})); Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{})); Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{})); Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{})); auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout{}, group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE) auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout{}, group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE) auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{}, group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA) auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA) auto [tLSEgLSE, tLSEsLSE] = tma_partition(params.tma_load_LSE, _0{}, Layout<_1>{}, sLSE, gLSE); // (TMA, k), (TMA, PIPE) auto [tLSEgdPsum, tLSEsdPsum] = tma_partition(params.tma_load_dPsum, _0{}, Layout<_1>{}, sdPsum, gdPsum); // (TMA, k), (TMA, PIPE) uint16_t mcast_mask_qdo = 0; if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{})); } } int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; int lane_predicate = cute::elect_one_sync(); // // Wait for the MMA warpgroups to say that smem_q is ready // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::QueryEmpty) /*id*/); if (lane_predicate) { // Copy K tile and V tile from GMEM to SMEM. 
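            // Transaction-byte bookkeeping for the one-shot K + V load (illustrative numbers,
            // assuming kBlockN = 128, kHeadDim = 128 and a 16-bit Element): TmaTransactionBytesK and
            // TmaTransactionBytesV are each 128 * 128 * 2 B = 32 KiB, so the elected lane tells
            // barrier_KV to expect 64 KiB in flight. The TMA hardware signals the barrier as the
            // bytes arrive, so no other thread participates in this load.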
shared_storage.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV); copy(params.tma_load_K.with(reinterpret_cast(shared_storage.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK); copy(params.tma_load_V.with(reinterpret_cast(shared_storage.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV); pipeline_q.producer_acquire(smem_pipe_write); copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index())); copy(params.tma_load_LSE.with(*pipeline_q.producer_get_barrier(smem_pipe_write), 0), tLSEgLSE(_, m_block), tLSEsLSE(_, smem_pipe_write.index())); #pragma unroll 2 for (; m_block < m_block_max - 1; ++m_block) { pipeline_do.producer_acquire(smem_pipe_write); copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write.index())); copy(params.tma_load_dPsum.with(*pipeline_do.producer_get_barrier(smem_pipe_write), 0), tLSEgdPsum(_, m_block), tLSEsdPsum(_, smem_pipe_write.index())); ++smem_pipe_write; pipeline_q.producer_acquire(smem_pipe_write); copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index())); copy(params.tma_load_LSE.with(*pipeline_q.producer_get_barrier(smem_pipe_write), 0), tLSEgLSE(_, m_block + 1), tLSEsLSE(_, smem_pipe_write.index())); } } scheduler_prefetch(); if (lane_predicate) { pipeline_do.producer_acquire(smem_pipe_write); copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write.index())); copy(params.tma_load_dPsum.with(*pipeline_do.producer_get_barrier(smem_pipe_write), 0), tLSEgdPsum(_, m_block), tLSEsdPsum(_, smem_pipe_write.index())); ++smem_pipe_write; } } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_q, MainloopPipeline pipeline_do, PipelineState& smem_pipe_write) { // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write PipelineState smem_pipe_write_do = smem_pipe_write; int lane_predicate = cute::elect_one_sync(); // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used * then would just be acquired since the phase was still inverted from make_producer_start_state */ pipeline_q.producer_tail(smem_pipe_write); pipeline_do.producer_tail(smem_pipe_write_do); } } template CUTLASS_DEVICE void store_dq(Params const& params, SharedStorage &shared_storage, cute::tuple block_coord ) { Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMA{}); Tensor sdQnoswizzle = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMANoSwizzle{}); auto [n_block, bidh, bidb] = block_coord; bool const is_varlen_q = Varlen && params.cu_seqlens_q != nullptr; // We reshaped dQaccum to have last dimension 32, so the offset needs to be multiplied by kHeadDim / 32 int const offset_padded = !is_varlen_q ? 0 : ((params.cu_seqlens_q[bidb] + bidb * 128) / 128 * 128) * (kHeadDim / ElemsPerRowTMA); // Prepare the TMA loads Tensor mdQaccum = params.tma_add_dQ.get_tma_tensor(params.shape_dQaccum)(_, _, bidh, !is_varlen_q ? 
bidb : 0); Tensor gdQaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdQaccum), TileShape_dQaccum{}, make_coord(_, _0{})); // (M, K, _) auto block_tma_dQ = params.tma_add_dQ.get_slice(_0{}); Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K) Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; int const num_batch = params.num_batch; int const num_head = get<2>(params.shape_Q); int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh; using Barrier = cutlass::GenericBarrier; int lane_predicate = cute::elect_one_sync(); #pragma unroll 2 for (; m_block < m_block_max; ++m_block) { if constexpr (Deterministic) { Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); } cutlass::arch::NamedBarrier::sync(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFull) /*id*/); // sdQ full, to be written to gmem if (lane_predicate) { cute::copy(params.tma_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block)); tma_store_arrive(); } tma_store_wait<0>(); if constexpr (Deterministic) { Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); } cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to } if constexpr (Is_local && Deterministic) { constexpr int kBlockM = get<0>(TileShape_MNK{}); int const seqlen_q = get_seqlen_q(params, bidb); int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM); #pragma unroll 2 for (; m_block < m_block_global_max; ++m_block) { Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); } } } CUTLASS_DEVICE void mma_init() { // // Tell producer (warp 0) that smem_q is ready // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::QueryEmpty) /*id*/); int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); if (cutlass::canonical_warp_group_idx() == 1 && warp_idx_in_warpgroup == 0) { cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to } } template CUTLASS_DEVICE void mma(Params const& params, MainloopPipeline pipeline_q, MainloopPipeline pipeline_do, PipelineState& smem_pipe_read, FrgTensordKV& tdKrdK, FrgTensordKV& tdVrdV, int thread_idx, int work_idx, cute::tuple block_coord, SharedStorage& shared_storage ) { static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); Tensor sQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQ{}); Tensor sdO = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdO{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_v.data()), SmemLayoutV{}); Tensor sQt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQt{}); Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdOt{}); Tensor sKt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutKt{}); Tensor sdS = 
make_tensor(make_smem_ptr(shared_storage.mainloop.smem_ds.data()), SmemLayoutdS{}); Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_ds.data()), SmemLayoutdSt{}); Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); Layout warp_group_thread_layout_dq = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); TiledMmaSdP tiled_mma_SdP; TiledMmadKV tiled_mma_dKV; TiledMmadQ tiled_mma_dQ; static_assert(!dKV_swapAB); auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)); auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_dQ = tiled_mma_dQ.get_slice(!Varlen ? warp_group_thread_layout_dq(NumdQWarpGroups == 2 ? warp_group_idx : 0) : thread_idx); // auto wg_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx); auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx); Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N) R2STiledCopydQaccum r2s_tiled_copy_dQaccum; // auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(NumdQWarpGroups == 2 ? thread_idx : thread_idx % cutlass::NumThreadsPerWarpGroup); Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ); // Allocate "fragments/descriptors" Tensor tSrQ = wg_mma_SdP.partition_fragment_B(sQ); Tensor tSrK = wg_mma_SdP.partition_fragment_A(sK); Tensor tdPrdO = wg_mma_SdP.partition_fragment_B(sdO); Tensor tdPrV = wg_mma_SdP.partition_fragment_A(sV); Tensor tdVrdO = wg_mma_dKV.partition_fragment_B(sdOt); Tensor tdKrQ = wg_mma_dKV.partition_fragment_B(sQt); int n_block = get<0>(block_coord); int bidh = get<1>(block_coord); int bidb = get<2>(block_coord); int const seqlen_q = get_seqlen_q(params, bidb); int const seqlen_k = get_seqlen_k(params, bidb); int m_block_max = get_m_block_max(params, n_block, bidb); int m_block_min = get_m_block_min(params, n_block, bidb); int m_block = m_block_min; // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the row indices. 
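        // With SdP_swapAB the S/dP accumulator is stored transposed, so the query (attention-row)
        // index lives in the first and third entries of the (2, 2, V) mode while the middle entry
        // walks along the key direction, where SmemLayoutLSEMma is broadcast (stride _0). Slicing
        // with make_coord(_, _0{}, _) and fixing the MMA modes therefore keeps exactly one LSE and
        // one dPsum value per accumulator row owned by this thread, shaped (2, V, PIPE) as noted below.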
Tensor tLSEsLSE = thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _0{}, _); // (2, V, PIPE) Tensor tLSEsdPsum = thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _0{}, _); clear(tdKrdK); clear(tdVrdV); // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero; cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_KV.try_wait(work_idx % 2)); if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_KV.wait(work_idx % 2); } auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); }; auto compute_dQ = [&]() { static_assert(!Mma_dQ_is_RS); // SMEM fence to make sure sP is written before it's read by WGMMA cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); if constexpr (!dQ_swapAB) { Tensor tdQrdS = wg_mma_dQ.partition_fragment_A(sdS); Tensor tdQrK = wg_mma_dQ.partition_fragment_B(sKt); flash::gemm(tiled_mma_dQ, tdQrdS(_, _, _, smem_pipe_read.index()), tdQrK, tdQrdQ); } else { Tensor tdQrdS = wg_mma_dQ.partition_fragment_B(sdS); Tensor tdQrK = wg_mma_dQ.partition_fragment_A(sKt); flash::gemm(tiled_mma_dQ, tdQrK, tdQrdS(_, _, _, smem_pipe_read.index()), tdQrdQ); } pipeline_q.consumer_release(smem_pipe_read); // release Q warpgroup_wait<0>(); Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFull) /*id*/); // sdQ full, to be written to gmem }; // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 // this helps quite a bit to not have to do causal masking for most of the iterations. 
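    // Number of masked leading iterations: for a fixed n_block the causal diagonal spans kBlockN
    // query positions, i.e. ceil_div(kBlockN, kBlockM) M-tiles, plus one extra tile because the
    // first M-tile generally straddles the diagonal rather than starting on it. For example
    // (illustrative numbers only), kBlockM = 64 and kBlockN = 128 give
    // n_masking_steps = 128 / 64 + 1 = 3; later m_blocks need no causal masking, only the
    // out-of-bounds column check in the second loop below.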
if constexpr (Is_causal) { static constexpr int n_masking_steps = cute::ceil_div(kBlockN, kBlockM) + 1; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < std::min(m_block_max, m_block_min + n_masking_steps); ++m_block) { Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{})); pipeline_q.consumer_wait(smem_pipe_read); flash::gemm(tiled_mma_SdP, tSrK, tSrQ(_, _, _, smem_pipe_read.index()), tSrS); Tensor tLSErLSE = make_fragment_like(tLSEsLSE(_, _, _0{})); cute::copy(tLSEsLSE(_, _, smem_pipe_read.index()), tLSErLSE); Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{})); pipeline_do.consumer_wait(smem_pipe_read); flash::gemm(tiled_mma_SdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read.index()), tdPrdP); warpgroup_wait<1>(); Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{})); Tensor taccScS = thread_mma_SdP.partition_C(cS); int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; #pragma unroll for (int i = 0; i < size(tSrS); ++i) { if (int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + causal_row_offset, seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } } // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); flash::scale_apply_exp2(scores, group_modes<0, 2>(tLSErLSE), params.softmax_scale_log2); Tensor tLSErdPsum = make_fragment_like(tLSEsdPsum(_, _, _0{})); cute::copy(tLSEsdPsum(_, _, smem_pipe_read.index()), tLSErdPsum); // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(tSrS); warpgroup_wait<0>(); // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - tLSErdPsum(mi)); } } Tensor rdS = flash::convert_type(tdPrdP); // Because of double buffering on dS, we don't need to sync here. // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. // But because both WGs have to sync at the end of the loop and double buffering, this race condition // is not possible. 
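        // What the loop above computed (softmax backward for this tile): dS_ij = P_ij * (dP_ij - dPsum_i),
        // where P_ij = exp2(softmax_scale_log2 * S_ij - LSE_log2_i) is already materialized in
        // `scores` and dPsum_i, the row sum of the elementwise product of dO and O, was precomputed
        // before this kernel runs. rdS is its fp16/bf16 copy; the copy below stages rdS into this
        // iteration's stage of sdS so the dQ GEMM can read it back through shared memory.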
Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, smem_pipe_read.index())); Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); flash::gemm(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read.index()), tdVrdV); Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); flash::gemm(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); pipeline_do.consumer_release(smem_pipe_read); // release dO compute_dQ(); ++smem_pipe_read; } } CUTLASS_PRAGMA_NO_UNROLL for (; m_block < m_block_max; ++m_block) { Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{})); pipeline_q.consumer_wait(smem_pipe_read); flash::gemm(tiled_mma_SdP, tSrK, tSrQ(_, _, _, smem_pipe_read.index()), tSrS); Tensor tLSErLSE = make_fragment_like(tLSEsLSE(_, _, _0{})); cute::copy(tLSEsLSE(_, _, smem_pipe_read.index()), tLSErLSE); Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{})); pipeline_do.consumer_wait(smem_pipe_read); flash::gemm(tiled_mma_SdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read.index()), tdPrdP); warpgroup_wait<1>(); Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{})); Tensor taccScS = thread_mma_SdP.partition_C(cS); if constexpr (!Is_local) { #pragma unroll for (int i = 0; i < size(tSrS); ++i) { if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } } } else { int local_row_offset_right = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + params.window_size_right; int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left; #pragma unroll for (int i = 0; i < size(tSrS); ++i) { if ((int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) || (int(get<0>(taccScS(i))) < std::max(int(get<1>(taccScS(i))) + local_row_offset_left, 0))) { tSrS(i) = -INFINITY; } } } // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tLSErLSE); } // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(scores); } flash::scale_apply_exp2(scores, group_modes<0, 2>(tLSErLSE), params.softmax_scale_log2); // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(scores); } Tensor tLSErdPsum = make_fragment_like(tLSEsdPsum(_, _, _0{})); cute::copy(tLSEsdPsum(_, _, smem_pipe_read.index()), tLSErdPsum); // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(tSrS); warpgroup_wait<0>(); // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - tLSErdPsum(mi)); } } // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dS); } Tensor rdS = flash::convert_type(tdPrdP); Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, smem_pipe_read.index())); Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); flash::gemm(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read.index()), tdVrdV); 
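        // dV accumulation: with SdP_swapAB the accumulator (and hence rP) is already in the (N, M)
        // orientation, so rP feeds this WGMMA directly as the register-resident A operand and
        // dV += P^T * dO needs no transpose through shared memory; the dK update below reuses rdS
        // the same way for dK += dS^T * Q.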
            Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout()));
            flash::gemm(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
            pipeline_do.consumer_release(smem_pipe_read);  // release dO
            compute_dQ();
            ++smem_pipe_read;
        }
        // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }
        #pragma unroll
        for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }
    }

};

} // namespace flash
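
// ---------------------------------------------------------------------------------------------
// Reference sketch (illustrative only, not used by the kernel above; the names and the plain-array
// layout are assumptions made for exposition). It spells out, on ordinary arrays, the per-tile math
// that the mainloop realizes with WGMMA + TMA:
//     dS_ij = P_ij * (dP_ij - dPsum_i),  dV += P^T * dO,  dK += scale * dS^T * Q,  dQ += scale * dS * K,
// where P is the softmax probability tile, dP = dO * V^T, and dPsum_i is the row sum of the
// elementwise product of dO and O.
namespace flash_bwd_reference_sketch {

inline void tile_backward(int M, int N, int D,
                          float const* P,      // (M, N) softmax probabilities for this tile
                          float const* dP,     // (M, N) = dO * V^T
                          float const* dPsum,  // (M)    row sums of dO elementwise-times O
                          float const* Q,      // (M, D)
                          float const* K,      // (N, D)
                          float const* dO,     // (M, D)
                          float* dQ,           // (M, D) accumulated in place
                          float* dK,           // (N, D) accumulated in place
                          float* dV,           // (N, D) accumulated in place
                          float softmax_scale) {
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float const p = P[i * N + j];
            float const dS = p * (dP[i * N + j] - dPsum[i]);  // softmax backward
            for (int d = 0; d < D; ++d) {
                dV[j * D + d] += p * dO[i * D + d];                  // dV += P^T * dO
                dK[j * D + d] += softmax_scale * dS * Q[i * D + d];  // dK += scale * dS^T * Q
                dQ[i * D + d] += softmax_scale * dS * K[i * D + d];  // dQ += scale * dS * K
            }
        }
    }
}

}  // namespace flash_bwd_reference_sketch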