12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851 |
- /******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
- #pragma once
- #include <cutlass/cutlass.h>
- #include <cutlass/array.h>
- #include <cutlass/numeric_types.h>
- #include <cutlass/numeric_conversion.h>
- #include <cutlass/barrier.h>
- #include "cutlass/pipeline/pipeline.hpp"
- #include "cute/tensor.hpp"
- #include "cutlass/gemm/collective/collective_builder.hpp"
- #include "named_barrier.hpp"
- #include "softmax.h"
- #include "utils.h"
- namespace flash {
- using namespace cute;
- template <int Stages, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
- bool Is_causal_, bool Varlen_, bool Deterministic,
- bool dKV_swapAB_, bool dQ_swapAB_,
- int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
- struct CollectiveMainloopBwd {
- static constexpr int kStages = Stages; // Number of pipeline stages; appears as the last mode of the Q/dO/dS/LSE smem layouts and sizes MainloopPipeline.
- using ClusterShape = ClusterShape_;
- using TileShape_MNK = TileShape_MNK_; // Modes 0/1/2 are read below as kBlockM / kBlockN / kHeadDim.
- using Element = Element_; // Element type of the Q/K/V/dO pointers and of the Q/K/V/dO/dS smem buffers.
- using ElementAccum = ElementAccum_; // Element type of dQaccum and of the LSE/dPsum smem buffers.
- using ArchTag = ArchTag_;
- static constexpr bool Is_causal = Is_causal_; // Only consulted by get_m_block_min in the visible part of the file.
- static constexpr bool Varlen = Varlen_; // Gates every cu_seqlens_* / seqused_* lookup.
- static constexpr bool SdP_swapAB = true; // Hard-coded rather than a template parameter, so every `!SdP_swapAB` alternative below is inactive.
- static constexpr bool dKV_swapAB = dKV_swapAB_;
- static constexpr bool dQ_swapAB = dQ_swapAB_;
- static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV. With SdP_swapAB fixed to true this rejects dKV_swapAB_ == true.
- static constexpr int kBlockM = get<0>(TileShape_MNK{});
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
- static constexpr int kHeadDim = get<2>(TileShape_MNK{});
- static constexpr int NumdQWarpGroups = 2;
- static constexpr int kNThreadsdQ = NumdQWarpGroups * cutlass::NumThreadsPerWarpGroup; // Thread count used for the dQFull / dQEmpty named barriers (plus one extra warp at the call sites).
- static_assert(ArchTag::kMinComputeCapability >= 90);
- static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1); // Only cluster mode 1 may be larger than 1.
- static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS. Always false here because SdP_swapAB is fixed to true.
- using TileShapeAtomSdP = std::conditional_t< // Tile shape handed to the op selector of TiledMmaSdP.
- !SdP_swapAB,
- Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
- Shape<Int<kBlockN>, Int<kBlockM / AtomLayoutMSdP>, Int<kHeadDim>> // Selected alternative (SdP_swapAB is true): M and N roles are exchanged.
- >;
- using AtomLayoutSdP = std::conditional_t< // Atom arrangement for TiledMmaSdP; the two factors multiply to 2 when AtomLayoutMSdP is 1 or 2.
- !SdP_swapAB,
- Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
- Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
- >;
- using TiledMmaSdP = decltype(cute::make_tiled_mma( // Its size() defines NumMmaThreads further down.
- cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
- AtomLayoutSdP{}));
- using TileShapeAtomdKV = std::conditional_t< // Tile shape handed to the op selector of TiledMmadKV; the reduction mode is kBlockM.
- !dKV_swapAB,
- Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
- Shape<Int<kHeadDim>, Int<kBlockN / AtomLayoutNdKV>, Int<kBlockM>>
- >;
- using AtomLayoutdKV = std::conditional_t<
- !dKV_swapAB,
- Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
- Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
- >;
- using TiledMmadKV = decltype(cute::make_tiled_mma(
- std::conditional_t<
- !SdP_swapAB, // Note: keyed on SdP_swapAB, not dKV_swapAB; since SdP_swapAB is true the rs_op_selector alternative is the one used.
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
- decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
- >{},
- AtomLayoutdKV{}));
- using TileShapeAtomdQ = std::conditional_t< // Tile shape handed to the op selector of TiledMmadQ; the reduction mode is kBlockN.
- !dQ_swapAB,
- Shape<Int<kBlockM>, Int<kHeadDim / (NumdQWarpGroups / AtomLayoutMdQ)>, Int<kBlockN>>,
- Shape<Int<kHeadDim>, Int<kBlockM / AtomLayoutMdQ>, Int<kBlockN>>
- >;
- using AtomLayoutdQ = std::conditional_t< // The two factors multiply to NumdQWarpGroups when AtomLayoutMdQ divides it.
- !dQ_swapAB,
- Layout<Shape<Int<AtomLayoutMdQ>, Int<NumdQWarpGroups / AtomLayoutMdQ>, _1>>,
- Layout<Shape<Int<NumdQWarpGroups / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
- >;
- static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN; // Same majors as spelled out literally in TiledMmadQ below; not referenced in the visible part of the file.
- static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
- using TiledMmadQ = decltype(cute::make_tiled_mma(
- std::conditional_t<
- !dQ_swapAB,
- std::conditional_t<
- Mma_dQ_is_RS, // Always false in this file (see Mma_dQ_is_RS), so the ss_op_selector alternative is taken when !dQ_swapAB.
- decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
- >,
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
- >{},
- AtomLayoutdQ{}));
- using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element, // K-major smem atom for Q; its second extent depends on dKV_swapAB / AtomLayoutNdKV.
- Int<kBlockM>, Int<dKV_swapAB ? kHeadDim : kHeadDim / (2 / AtomLayoutNdKV)>>());
- using SmemLayoutQ = // (kBlockM, kHeadDim, kStages): one Q tile per pipeline stage.
- decltype(tile_to_shape(SmemLayoutAtomQ{},
- make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
- using SmemLayoutdO = SmemLayoutQ; // dO uses exactly the Q layout.
- using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element, // K-major smem atom for K; its second extent depends on dQ_swapAB / AtomLayoutMdQ.
- Int<kBlockN>, Int<dQ_swapAB ? kHeadDim : kHeadDim / (NumdQWarpGroups / AtomLayoutMdQ)>>());
- using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); // (kBlockN, kHeadDim), single-buffered (no stage mode).
- using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); // (kBlockN, kHeadDim), single-buffered.
- using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
- using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); // (kBlockM, kBlockN), no stage mode.
- using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element, // Same selector arguments as SmemLayoutAtomP.
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
- using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, make_shape(Int<kBlockM>{}, Int<kBlockN>{}, Int<kStages>{}))); // (kBlockM, kBlockN, kStages).
- // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80
- using SmemLayoutLSE = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>, cute::Stride<_1, Int<cute::round_up(kBlockM, 32)>>>; // (kBlockM, kStages); stage stride is kBlockM rounded up to a multiple of 32.
- using SmemLayoutLSEMma = cute::Layout<cute::Shape<Int<kBlockN>, Int<kBlockM>, Int<kStages>>, cute::Stride<_0, _1, Int<cute::round_up(kBlockM, 32)>>>; // Same storage viewed as (kBlockN, kBlockM, kStages); mode 0 has stride 0, so every index along it reads the same row values.
- // Note this is the transpose in terms of the view, not in terms of memory.
- using SmemLayoutQt = // (kHeadDim, kBlockM, kStages) view composed onto SmemLayoutQ.
- decltype(cute::composition(SmemLayoutQ{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
- make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
- using SmemLayoutdOt = // (kHeadDim, kBlockM, kStages) view composed onto SmemLayoutdO.
- decltype(cute::composition(SmemLayoutdO{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
- make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
- using SmemLayoutKt = // (kHeadDim, kBlockN) view composed onto SmemLayoutK.
- decltype(cute::composition(SmemLayoutK{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
- make_stride(Int<kBlockN>{}, _1{}))));
- using SmemLayoutPt = // (kBlockN, kBlockM) view composed onto SmemLayoutP.
- decltype(cute::composition(SmemLayoutP{},
- make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- using SmemLayoutdSt = // (kBlockN, kBlockM, kStages) view composed onto SmemLayoutdS.
- decltype(cute::composition(SmemLayoutdS{},
- make_layout(make_shape(Int<kBlockN>{}, Int<kBlockM>{}, Int<kStages>{}),
- make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kBlockN>{}))));
- // Thread layout: kNThreadsdQ threads in a single row (256 if a warp group is 128 threads)
- using R2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreadsdQ>>, Stride<_1>>;
- using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, R2SLayoutAtomdQaccum{},
- Layout<Shape < _4>>{})); // Val layout, 4 vals per store
- using SmemLayoutdQaccum = Layout<Shape<Int<kBlockM * kHeadDim>>, Stride<_1>>; // Flat 1-D view of the dQ accumulator buffer (kBlockM * kHeadDim elements).
- // We want dQaccum smem to have last dimension 32, so that we only need to do 1 TMA instruction.
- // The layout Layout_K_SW128_Atom<ElementAccum> has 32 elements per row.
- // // TMA limit is that each dimension in smem must be <= 256.
- // static constexpr int ElemsPerRowTMA = (kBlockM * kHeadDim) / 32 <= 256 ? 32 : 64;
- static constexpr int ElemsPerRowTMA = 32; // If we change this, we'll also need to change the dQ shape in host.
- static_assert((kBlockM * kHeadDim) % ElemsPerRowTMA == 0);
- using TileShape_dQaccum = cute::Shape<Int<(kBlockM * kHeadDim) / ElemsPerRowTMA>, Int<ElemsPerRowTMA>>; // The same kBlockM * kHeadDim elements reshaped to rows of ElemsPerRowTMA.
- // using TileShape_dQaccum = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
- using SmemLayoutdQaccumTMA = // Swizzled 2-D layout over the dQ accumulator buffer, used as the TMA source in store_dq.
- decltype(tile_to_shape(GMMA::Layout_K_SW128_Atom<ElementAccum>{}, TileShape_dQaccum{}));
- using SmemLayoutdQaccumTMANoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutdQaccumTMA{}));
- using SmemCopyAtomPdS = Copy_Atom< // With SdP_swapAB fixed to true this resolves to the SM90_U16x8_STSM_T alternative.
- std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- using SmemCopyAtomdKV = Copy_Atom<
- std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{}))); // TMA op for Q/dO chosen from cluster mode 1; load() checks whether it is SM90_TMA_LOAD_MULTICAST.
- using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
- using GmemTiledCopydQaccum = cute::SM90_TMA_REDUCE_ADD; // dQ is accumulated into global memory rather than overwritten.
- using GmemTiledCopyLSE = cute::SM90_TMA_LOAD;
- using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
- using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>; // Mode 1 (d) is compile-time contiguous.
- using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
- using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
- using TMA_QdO = decltype(make_tma_copy( // Descriptor type shared by the Q and dO loads; must match what to_underlying_arguments constructs.
- GmemTiledCopyQdO{},
- make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
- take<0, 2>(SmemLayoutQ{}),
- select<0, 2>(TileShape_MNK{}),
- size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
- using TMA_K = decltype(make_tma_copy(
- GmemTiledCopyKV{},
- make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
- SmemLayoutK{},
- select<1, 2>(TileShape_MNK{}),
- _1{})); // no mcast for KV
- using TMA_V = decltype(make_tma_copy(
- GmemTiledCopyKV{},
- make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
- SmemLayoutV{},
- select<1, 2>(TileShape_MNK{}),
- _1{})); // no mcast for KV
- using TMA_add_dQ = decltype(make_tma_copy( // Reduce-add descriptor type for dQaccum; note the non-const ElementAccum pointer.
- GmemTiledCopydQaccum{},
- make_tensor(make_gmem_ptr(static_cast<ElementAccum*>(nullptr)), ShapeQKV{}, StrideQKV{}),
- SmemLayoutdQaccumTMA{},
- TileShape_dQaccum{},
- _1{})); // no mcast for dQ
- using TMA_LSE = decltype(make_tma_copy( // Descriptor type shared by the LSE and dPsum loads (one stage of kBlockM values).
- GmemTiledCopyLSE{},
- make_tensor(make_gmem_ptr(static_cast<ElementAccum const*>(nullptr)), ShapeLSE{}, StrideLSE{}),
- select<0>(SmemLayoutLSE{}),
- select<0>(TileShape_MNK{}),
- _1{})); // no mcast for LSE
- static constexpr int NumMmaThreads = size(TiledMmaSdP{}); // Thread count of the S/dP tiled MMA.
- using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>; // Used for both pipeline_q and pipeline_do.
- using PipelineState = typename MainloopPipeline::PipelineState;
- // Set the bytes transferred in this TMA transaction (may involve multiple issues)
- static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v<Element> / 8); // One stage of Q (the stage mode is dropped by take<0, 2>).
- static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(SmemLayoutK{}) * cutlass::sizeof_bits_v<Element> / 8); // Whole K tile.
- static constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size(SmemLayoutV{}) * cutlass::sizeof_bits_v<Element> / 8); // Whole V tile; load() expects K + V bytes on barrier_KV.
- static constexpr uint32_t TmaTransactionBytesLSE = static_cast<uint32_t>(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v<ElementAccum> / 8); // One stage of LSE (kBlockM accumulator-typed values).
- struct TensorStorage : cute::aligned_struct<1024> { // Shared-memory buffers of the mainloop; member order defines the smem layout.
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
- // It's important that smem_dqacc is aligned to 1024 bytes for the TMA, so that the 1st row
- // has no swizzle.
- // If the address is only 128 bytes aligned, it's possible that the 1st row has swizzle
- // and when we read it back in the postprocess kernel, the swizzle will not match.
- cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>, 1024> smem_dqacc; // Viewed as SmemLayoutdQaccum in mma() and as SmemLayoutdQaccumTMA in store_dq().
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q; // kStages Q tiles.
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do; // kStages dO tiles.
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds; // kStages dS tiles.
- cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse; // kStages rows of LSE, 128-byte aligned.
- cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum; // Same layout as smem_lse, holds dPsum.
- };
- static constexpr int SharedStorageQdOSize = sizeof(decltype((TensorStorage{}).smem_q)) + sizeof(decltype((TensorStorage{}).smem_do)) + sizeof(decltype((TensorStorage{}).smem_ds)) + sizeof(decltype((TensorStorage{}).smem_dqacc)); // Combined byte size of the Q, dO, dS and dQ-accumulator buffers (K, V, LSE and dPsum are not counted).
- // Host side kernel arguments. Consumed only by to_underlying_arguments, which turns them into Params.
- struct Arguments {
- Element const* ptr_Q;
- ShapeQKV const shape_Q; // (seqlen, d, head, batch); also used as the shape of dO.
- StrideQKV const stride_Q;
- Element const* ptr_K;
- ShapeQKV const shape_K; // (seqlen, d, head, batch); also used as the shape of V.
- StrideQKV const stride_K;
- Element const* ptr_V;
- StrideQKV const stride_V;
- Element const* ptr_dO;
- StrideQKV const stride_dO;
- ElementAccum* ptr_dQaccum; // Destination of the TMA reduce-add in store_dq.
- ShapeQKV const shape_dQaccum; // Expected to already be reshaped to rows of ElemsPerRowTMA elements (see the note on ElemsPerRowTMA).
- StrideQKV const stride_dQaccum;
- float const* ptr_LSE_log2; // NOTE(review): presumably the log-sum-exp already scaled to base 2, going by the name -- confirm with the producer of this buffer.
- ShapeLSE const shape_LSE; // (seqlen, head, batch); also used as the shape of dPsum.
- StrideLSE const stride_LSE_log2;
- float const* ptr_dPsum;
- StrideLSE const stride_dPsum;
- float const softmax_scale;
- int num_batch;
- int* dq_semaphore; // Asserted non-null by to_underlying_arguments when Deterministic is set.
- int const* cu_seqlens_q = nullptr; // When null (or !Varlen) the Q sequence length comes from shape_Q.
- int const* cu_seqlens_k = nullptr; // When null (or !Varlen) the K sequence length comes from shape_K.
- int const* seqused_k = nullptr; // NOTE(review): forwarded positionally into Params::seqused_q -- the names do not match; confirm what callers pass here.
- int const* seqused_v = nullptr; // NOTE(review): forwarded positionally into Params::seqused_k -- the names do not match; confirm what callers pass here.
- };
- // Device side kernel params. Brace-initialized in this exact member order by to_underlying_arguments.
- struct Params {
- ShapeQKV const shape_Q;
- ShapeQKV const shape_K;
- ShapeQKV const shape_dQaccum;
- cutlass::FastDivmod qhead_per_khead_divmod; // Divisor is ceil_div(Q heads, K heads); load() uses it to map bidh to bidh_kv.
- TMA_QdO tma_load_Q, tma_load_dO;
- TMA_K tma_load_K;
- TMA_V tma_load_V;
- TMA_add_dQ tma_add_dQ;
- TMA_LSE tma_load_LSE, tma_load_dPsum;
- float const* ptr_LSE_log2;
- ShapeLSE const shape_LSE;
- StrideLSE const stride_LSE_log2;
- float const* ptr_dPsum;
- StrideLSE const stride_dPsum;
- float const softmax_scale;
- float const softmax_scale_log2; // softmax_scale * M_LOG2E, computed in to_underlying_arguments.
- int num_batch;
- int* dq_semaphore; // Base of the per-(batch, head) locks used by store_dq when Deterministic.
- int const* cu_seqlens_q = nullptr; // Non-null enables the variable-length Q path (when Varlen).
- int const* cu_seqlens_k = nullptr; // Non-null enables the variable-length K path (when Varlen).
- int const* seqused_q = nullptr; // If set, overrides the cu_seqlens_q difference in get_seqlen_q. Filled from Arguments::seqused_k (see NOTE there).
- int const* seqused_k = nullptr; // If set, overrides the cu_seqlens_k difference in get_seqlen_k. Filled from Arguments::seqused_v (see NOTE there).
- };
- /// Builds the device-side Params from the host-side Arguments.
- /// Each global pointer is wrapped in a tensor view and paired with its TMA descriptor;
- /// the remaining shapes, strides, scalars and pointers are forwarded in Params' member order,
- /// together with softmax_scale * log2(e).
- /// NOTE(review): args.seqused_k / args.seqused_v fill Params::seqused_q / Params::seqused_k
- /// (aggregate order) -- the field names do not line up; confirm against the caller.
- static Params
- to_underlying_arguments(Arguments const& args) {
-     // Global-memory views. dO reuses Q's shape, V reuses K's shape, dPsum reuses the LSE shape.
-     auto gmem_Q = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
-     auto gmem_dO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO);
-     auto gmem_K = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
-     auto gmem_V = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V);
-     auto gmem_dQaccum = make_tensor(make_gmem_ptr(args.ptr_dQaccum), args.shape_dQaccum, args.stride_dQaccum);
-     auto gmem_LSE = make_tensor(make_gmem_ptr(args.ptr_LSE_log2), args.shape_LSE, args.stride_LSE_log2);
-     auto gmem_dPsum = make_tensor(make_gmem_ptr(args.ptr_dPsum), args.shape_LSE, args.stride_dPsum);
-     // Q / dO: stage 0 of the smem layout describes one tile; mcast along N mode for this M load, if any.
-     TMA_QdO desc_Q = make_tma_copy(
-         GmemTiledCopyQdO{}, gmem_Q, SmemLayoutQ{}(_, _, _0{}),
-         select<0, 2>(TileShape_MNK{}), size<1>(ClusterShape{}));
-     TMA_QdO desc_dO = make_tma_copy(
-         GmemTiledCopyQdO{}, gmem_dO, SmemLayoutdO{}(_, _, _0{}),
-         select<0, 2>(TileShape_MNK{}), size<1>(ClusterShape{}));
-     // K / V: whole (kBlockN, kHeadDim) tile, no mcast.
-     TMA_K desc_K = make_tma_copy(
-         GmemTiledCopyKV{}, gmem_K, SmemLayoutK{},
-         select<1, 2>(TileShape_MNK{}), _1{});
-     TMA_V desc_V = make_tma_copy(
-         GmemTiledCopyKV{}, gmem_V, SmemLayoutV{},
-         select<1, 2>(TileShape_MNK{}), _1{});
-     // dQaccum: reduce-add descriptor over the swizzled accumulator layout, no mcast.
-     TMA_add_dQ desc_dQ = make_tma_copy(
-         GmemTiledCopydQaccum{}, gmem_dQaccum, SmemLayoutdQaccumTMA{},
-         TileShape_dQaccum{}, _1{});
-     // LSE / dPsum: one stage of kBlockM values each, no mcast.
-     TMA_LSE desc_LSE = make_tma_copy(
-         GmemTiledCopyLSE{}, gmem_LSE, select<0>(SmemLayoutLSE{}),
-         select<0>(TileShape_MNK{}), _1{});
-     TMA_LSE desc_dPsum = make_tma_copy(
-         GmemTiledCopyLSE{}, gmem_dPsum, select<0>(SmemLayoutLSE{}),
-         select<0>(TileShape_MNK{}), _1{});
-     // The deterministic path needs the semaphore buffer.
-     if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }
-     float const scale_log2 = float(args.softmax_scale * M_LOG2E);
-     return Params{
-         args.shape_Q, args.shape_K, args.shape_dQaccum,
-         cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
-         desc_Q, desc_dO, desc_K, desc_V, desc_dQ, desc_LSE, desc_dPsum,
-         args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2,
-         args.ptr_dPsum, args.stride_dPsum,
-         args.softmax_scale, scale_log2,
-         args.num_batch, args.dq_semaphore,
-         args.cu_seqlens_q, args.cu_seqlens_k,
-         args.seqused_k, args.seqused_v};
- }
- /// Prefetches every TMA descriptor stored in Params -- ideally called from a single thread for best performance.
- CUTLASS_DEVICE
- static void prefetch_tma_descriptors(Params const& params) {
-     auto const prefetch = [](auto const& tma) {
-         cute::prefetch_tma_descriptor(tma.get_tma_descriptor());
-     };
-     // The six load descriptors first, then the dQ reduce-add descriptor.
-     prefetch(params.tma_load_Q);
-     prefetch(params.tma_load_dO);
-     prefetch(params.tma_load_K);
-     prefetch(params.tma_load_V);
-     prefetch(params.tma_load_LSE);
-     prefetch(params.tma_load_dPsum);
-     prefetch(params.tma_add_dQ);
- }
- /// Query sequence length for batch entry `bidb`.
- /// Falls back to the static length in shape_Q unless Varlen is enabled and cu_seqlens_q is set;
- /// in that case seqused_q[bidb] wins if present, otherwise the cu_seqlens_q difference is used.
- CUTLASS_DEVICE
- int get_seqlen_q(Params const& params, int bidb) {
-     if constexpr (Varlen) {
-         if (params.cu_seqlens_q != nullptr) {
-             if (params.seqused_q) { return params.seqused_q[bidb]; }
-             return params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb];
-         }
-     }
-     return get<0>(params.shape_Q);
- }
- /// Key sequence length for batch entry `bidb`.
- /// Falls back to the static length in shape_K unless Varlen is enabled and cu_seqlens_k is set;
- /// in that case seqused_k[bidb] wins if present, otherwise the cu_seqlens_k difference is used.
- CUTLASS_DEVICE
- int get_seqlen_k(Params const& params, int bidb) {
-     if constexpr (Varlen) {
-         if (params.cu_seqlens_k != nullptr) {
-             if (params.seqused_k) { return params.seqused_k[bidb]; }
-             return params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb];
-         }
-     }
-     return get<0>(params.shape_K);
- }
- /// First m_block this n_block has to visit.
- /// Non-causal: always 0. Causal: (n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM, clamped at 0.
- CUTLASS_DEVICE
- int get_m_block_min(Params const& params, int n_block, int bidb) {
-     if constexpr (!Is_causal) {
-         return 0;
-     } else {
-         int const len_q = get_seqlen_q(params, bidb);
-         int const len_k = get_seqlen_k(params, bidb);
-         int const first_block = (n_block * kBlockN + len_q - len_k) / kBlockM;
-         return std::max(0, first_block);
-     }
- }
- /// Producer side of the mainloop: issues all TMA loads for one (n_block, head, batch) work tile.
- ///
- /// The K and V tiles are loaded once and signalled through shared_storage.barrier_KV. Then, for
- /// every m_block in [m_block_min, m_block_max), a Q tile plus its LSE slice goes through
- /// pipeline_q and a dO tile plus its dPsum slice goes through pipeline_do. Both pipelines are
- /// driven by the same smem_pipe_write state, which is advanced once per m_block. Q/LSE for
- /// block m+1 is issued right after dO/dPsum for block m.
- /// All TMA copies are issued by one elected lane; scheduler_prefetch() is called by every lane,
- /// just before the last dO/dPsum load.
- ///
- /// @param params              Device params (TMA descriptors, shapes, varlen pointers).
- /// @param pipeline_q          Pipeline guarding the Q + LSE smem stages.
- /// @param pipeline_do         Pipeline guarding the dO + dPsum smem stages.
- /// @param smem_pipe_write     Producer pipeline state; advanced in place.
- /// @param shared_storage      Must expose `.mainloop` (a TensorStorage) and `.barrier_KV`.
- /// @param scheduler_prefetch  Callable invoked once per call, with no arguments.
- /// @param block_coord         (n_block, bidh, bidb).
- /// @param work_idx            Not used in this function.
- template <typename SchedulerPrefetch, typename SharedStorage>
- CUTLASS_DEVICE void
- load(Params const& params,
-      MainloopPipeline pipeline_q,
-      MainloopPipeline pipeline_do,
-      PipelineState& smem_pipe_write,
-      SharedStorage &shared_storage,
-      SchedulerPrefetch const& scheduler_prefetch,
-      cute::tuple<int32_t, int32_t, int32_t> block_coord,
-      int work_idx
-      ) {
-     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQ{});
-     Tensor sdO = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdO{});
-     Tensor sK = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutK{});
-     Tensor sV = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_v.data()), SmemLayoutV{});
-     Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_lse.data()), SmemLayoutLSE{});
-     Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dpsum.data()), SmemLayoutLSE{});
-     auto [n_block, bidh, bidb] = block_coord;
-     // Map the query head to the K/V head it shares.
-     int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
-     // Prepare the TMA loads
-     uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
-     constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
-     uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
-     bool const is_varlen_q = Varlen && params.cu_seqlens_q != nullptr;
-     bool const is_varlen_k = Varlen && params.cu_seqlens_k != nullptr;
-     // In the varlen case batch index 0 is used and the sequence is located with a row offset instead.
-     Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
-     Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
-     Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
-     Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
-     Tensor mLSE = params.tma_load_LSE.get_tma_tensor(params.shape_LSE)(_, bidh, !is_varlen_q ? bidb : 0);
-     Tensor mdPsum = params.tma_load_dPsum.get_tma_tensor(params.shape_LSE)(_, bidh, !is_varlen_q ? bidb : 0);
-     int const offset_q = !is_varlen_q ? 0 : params.cu_seqlens_q[bidb];
-     int const offset_k = !is_varlen_k ? 0 : params.cu_seqlens_k[bidb];
-     // LSE / dPsum use a padded offset: (cu_seqlens_q[bidb] + 128 * bidb) rounded down to a multiple of 128.
-     int const offset_padded = !is_varlen_q ? 0 : (params.cu_seqlens_q[bidb] + bidb * 128) / 128 * 128;
-     Tensor gQ = local_tile(domain_offset(make_coord(offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
-     Tensor gdO = local_tile(domain_offset(make_coord(offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
-     Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
-     Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
-     Tensor gLSE = local_tile(domain_offset(make_coord(offset_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)
-     Tensor gdPsum = local_tile(domain_offset(make_coord(offset_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)
-     // K and V have no stage mode; append a trailing size-1 mode so they can go through tma_partition like the others.
-     Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{}));
-     Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{}));
-     Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{}));
-     Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{}));
-     auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout<ClusterShape>{},
-                                       group_modes<0, 2>(sQ), group_modes<0, 2>(gQ));  // (TMA, k), (TMA, PIPE)
-     auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout<ClusterShape>{},
-                                           group_modes<0, 2>(sdO), group_modes<0, 2>(gdO));  // (TMA, k), (TMA, PIPE)
-     auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{},
-                                       group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x));  // (TMA), (TMA)
-     auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{},
-                                       group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x));  // (TMA), (TMA)
-     auto [tLSEgLSE, tLSEsLSE] = tma_partition(params.tma_load_LSE, _0{}, Layout<_1>{},
-                                               sLSE, gLSE);  // (TMA, k), (TMA, PIPE)
-     auto [tLSEgdPsum, tLSEsdPsum] = tma_partition(params.tma_load_dPsum, _0{}, Layout<_1>{},
-                                                   sdPsum, gdPsum);  // (TMA, k), (TMA, PIPE)
-     // Multicast mask for Q/dO: one bit per block that shares this block's M coordinate, i.e. all
-     // n along cluster mode 1. The M coordinate goes first and the loop index n second, matching
-     // the (m,n) -> block_id layout; passing them the other way round always yields bit 0 only,
-     // because mode 0 has extent 1.
-     uint16_t mcast_mask_qdo = 0;
-     if constexpr (cute::is_same_v<GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
-         auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
-         for (int n = 0; n < size<1>(block_layout); ++n) {
-             mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{}));
-         }
-     }
-     int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{}));
-     int m_block_min = get_m_block_min(params, n_block, bidb);
-     int m_block = m_block_min;
-     int lane_predicate = cute::elect_one_sync();
-     // // Wait for the MMA warpgroups to say that smem_q is ready
-     // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::QueryEmpty) /*id*/);
-     if (lane_predicate) {
-         // Copy K tile and V tile from GMEM to SMEM.
-         shared_storage.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV);
-         copy(params.tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK);
-         copy(params.tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV);
-         // First Q + LSE tile.
-         pipeline_q.producer_acquire(smem_pipe_write);
-         copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index()));
-         copy(params.tma_load_LSE.with(*pipeline_q.producer_get_barrier(smem_pipe_write), 0), tLSEgLSE(_, m_block), tLSEsLSE(_, smem_pipe_write.index()));
-         #pragma unroll 2
-         for (; m_block < m_block_max - 1; ++m_block) {
-             // dO + dPsum for the current block ...
-             pipeline_do.producer_acquire(smem_pipe_write);
-             copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write.index()));
-             copy(params.tma_load_dPsum.with(*pipeline_do.producer_get_barrier(smem_pipe_write), 0), tLSEgdPsum(_, m_block), tLSEsdPsum(_, smem_pipe_write.index()));
-             ++smem_pipe_write;
-             // ... then Q + LSE for the next block, into the next stage.
-             pipeline_q.producer_acquire(smem_pipe_write);
-             copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index()));
-             copy(params.tma_load_LSE.with(*pipeline_q.producer_get_barrier(smem_pipe_write), 0), tLSEgLSE(_, m_block + 1), tLSEsLSE(_, smem_pipe_write.index()));
-         }
-     }
-     // Called by every lane, before the final dO/dPsum load is issued.
-     scheduler_prefetch();
-     if (lane_predicate) {
-         // dO + dPsum for the last block.
-         pipeline_do.producer_acquire(smem_pipe_write);
-         copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write.index()));
-         copy(params.tma_load_dPsum.with(*pipeline_do.producer_get_barrier(smem_pipe_write), 0), tLSEgdPsum(_, m_block), tLSEsdPsum(_, smem_pipe_write.index()));
-         ++smem_pipe_write;
-     }
- }
- /// Producer epilogue: runs producer_tail on both pipelines to prevent early exit of blocks in a Cluster.
- /// producer_tail waits for all stages to either be released (all Consumer UNLOCKs), or, if a stage
- /// was never used, just acquires it since its phase is still inverted from make_producer_start_state.
- CUTLASS_DEVICE void
- load_tail(MainloopPipeline pipeline_q, MainloopPipeline pipeline_do,
-           PipelineState& smem_pipe_write) {
-     // pipeline_do gets its own copy of the state, taken before pipeline_q's tail call.
-     PipelineState do_tail_state = smem_pipe_write;
-     // Only one elected lane issues the epilogue waits.
-     if (cute::elect_one_sync()) {
-         pipeline_q.producer_tail(smem_pipe_write);
-         pipeline_do.producer_tail(do_tail_state);
-     }
- }
- template <typename SharedStorage> // store_dq: for each m_block, adds the dQ accumulator tile held in smem_dqacc into global dQaccum with a TMA reduce-add.
- CUTLASS_DEVICE void
- store_dq(Params const& params,
- SharedStorage &shared_storage, // Must expose `.mainloop.smem_dqacc`.
- cute::tuple<int32_t, int32_t, int32_t> block_coord // Unpacked below as (n_block, bidh, bidb).
- ) {
- Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMA{});
- Tensor sdQnoswizzle = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMANoSwizzle{}); // Not referenced again in this function.
- auto [n_block, bidh, bidb] = block_coord;
- bool const is_varlen_q = Varlen && params.cu_seqlens_q != nullptr;
- // We reshaped dQaccum to have last dimension 32, so the offset needs to be multiplied by kHeadDim / 32
- int const offset_padded = !is_varlen_q ? 0 : ((params.cu_seqlens_q[bidb] + bidb * 128) / 128 * 128) * (kHeadDim / ElemsPerRowTMA); // Same 128-padded row offset that load() uses for LSE, scaled to rows of ElemsPerRowTMA elements.
- // Prepare the TMA reduce-add store: global destination tiles and the smem source partition
- Tensor mdQaccum = params.tma_add_dQ.get_tma_tensor(params.shape_dQaccum)(_, _, bidh, !is_varlen_q ? bidb : 0);
- Tensor gdQaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdQaccum), TileShape_dQaccum{}, make_coord(_, _0{})); // (M, K, _)
- auto block_tma_dQ = params.tma_add_dQ.get_slice(_0{});
- Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K)
- Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K)
- int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{})); // Same m_block range as load().
- int m_block_min = get_m_block_min(params, n_block, bidb);
- int m_block = m_block_min;
- int const num_batch = params.num_batch;
- int const num_head = get<2>(params.shape_Q);
- int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh; // One lock slot per (batch, head); only dereferenced on the Deterministic path.
- using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
- int lane_predicate = cute::elect_one_sync();
- #pragma unroll 2
- for (; m_block < m_block_max; ++m_block) {
- if constexpr (Deterministic) {
- Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block); // Presumably blocks until the flag at index m_block * num_batch * num_head equals n_block, so n_blocks add in ascending order -- confirm against GenericBarrier.
- }
- cutlass::arch::NamedBarrier::sync(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQFull) /*id*/); // sdQ full, to be written to gmem
- if (lane_predicate) { // Only the elected lane issues the TMA reduce-add.
- cute::copy(params.tma_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block));
- tma_store_arrive();
- }
- tma_store_wait<0>(); // Executed by every lane, before the lock release and the dQEmpty arrive below.
- if constexpr (Deterministic) {
- Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head); // Bumps the same flag that wait_eq above polls.
- }
- cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
- }
- }
- /// One-time setup on the MMA side: warp 0 of warp group 1 arrives once on the dQEmpty named
- /// barrier (the same barrier store_dq arrives on after each tile). All other warps do nothing.
- CUTLASS_DEVICE void
- mma_init() {
-     // // Tell producer (warp 0) that smem_q is ready
-     // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::QueryEmpty) /*id*/);
-     // Warp index inside its warp group, broadcast from lane 0 of the warp.
-     int const warp_in_group = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
-     bool const is_signalling_warp = cutlass::canonical_warp_group_idx() == 1 && warp_in_group == 0;
-     if (!is_signalling_warp) { return; }
-     // sdQ empty, ready to be written to
-     cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/);
- }
- // Consumer-side main loop of the backward pass for one (n_block, bidh, bidb) work tile.
- //
- // For every m_block in [m_block_min, m_block_max) the loop body
- //   1. waits on pipeline_q and issues the SdP GEMM of K with the current Q stage into tSrS (an N x M tile),
- //   2. waits on pipeline_do and issues the SdP GEMM of V with the current dO stage into tdPrdP,
- //   3. masks tSrS and applies scale_apply_exp2 with the per-row LSE values read from smem_lse,
- //   4. forms dS = scores * (dP - dPsum) in place in tdPrdP and copies the converted result to smem_ds,
- //   5. accumulates into tdVrdV (from rP and dO) and into tdKrdK (from rdS and Q),
- //   6. releases the dO stage and calls compute_dQ(), which releases the Q stage, writes the dQ tile to
- //      smem_dqacc and arrives on the dQFull named barrier.
- // On exit tdKrdK has additionally been multiplied by params.softmax_scale.
- //
- // params          kernel parameters; read for sequence lengths, softmax_scale_log2 and softmax_scale.
- // pipeline_q      pipeline consumed for the Q stages (released inside compute_dQ).
- // pipeline_do     pipeline consumed for the dO stages.
- // smem_pipe_read  read-side pipeline state; advanced once per processed m_block.
- // tdKrdK, tdVrdV  register-resident accumulators (enforced by static_assert); cleared on entry.
- // thread_idx      thread index used to select the MMA / copy slices.
- // work_idx        only its parity (work_idx % 2) is used, as the phase for barrier_KV.
- // block_coord     (n_block, bidh, bidb).
- // shared_storage  shared-memory buffers and barrier_KV.
- template <typename SharedStorage, typename FrgTensordKV>
- CUTLASS_DEVICE void
- mma(Params const& params,
- MainloopPipeline pipeline_q,
- MainloopPipeline pipeline_do,
- PipelineState& smem_pipe_read,
- FrgTensordKV& tdKrdK,
- FrgTensordKV& tdVrdV,
- int thread_idx,
- int work_idx,
- cute::tuple<int32_t, int32_t, int32_t> block_coord,
- SharedStorage& shared_storage
- ) {
- static_assert(is_rmem<FrgTensordKV>::value, "dK and dV tensor must be rmem resident.");
- // Views over the shared-memory buffers. Q, dO, K and dS are each viewed through two layouts
- // (the plain one and the "t"-suffixed one) on top of the same storage.
- Tensor sQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQ{});
- Tensor sdO = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdO{});
- Tensor sK = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutK{});
- Tensor sV = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_v.data()), SmemLayoutV{});
- Tensor sQt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQt{});
- Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdOt{});
- Tensor sKt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutKt{});
- Tensor sdS = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_ds.data()), SmemLayoutdS{});
- Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_ds.data()), SmemLayoutdSt{});
- Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});
- Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_lse.data()), SmemLayoutLSEMma{});
- Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{});
- static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and
- stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and
- size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and
- size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup,
- "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
- constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
- Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),
- make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
- Layout warp_group_thread_layout_dq = make_layout(make_shape(Int<NumdQWarpGroups>{}),
- make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
- // Broadcast from lane 0 so the warp-group index is uniform across the warp.
- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
- TiledMmaSdP tiled_mma_SdP;
- TiledMmadKV tiled_mma_dKV;
- TiledMmadQ tiled_mma_dQ;
- static_assert(!dKV_swapAB);
- auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx));
- auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);
- auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx));
- auto wg_mma_dQ = tiled_mma_dQ.get_slice(!Varlen ? warp_group_thread_layout_dq(NumdQWarpGroups == 2 ? warp_group_idx : 0) : thread_idx);
- // auto wg_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx);
- auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP);
- auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx);
- Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
- R2STiledCopydQaccum r2s_tiled_copy_dQaccum;
- // auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
- auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(NumdQWarpGroups == 2 ? thread_idx : thread_idx % cutlass::NumThreadsPerWarpGroup);
- Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ);
- // Allocate "fragments/descriptors"
- Tensor tSrQ = wg_mma_SdP.partition_fragment_B(sQ);
- Tensor tSrK = wg_mma_SdP.partition_fragment_A(sK);
- Tensor tdPrdO = wg_mma_SdP.partition_fragment_B(sdO);
- Tensor tdPrV = wg_mma_SdP.partition_fragment_A(sV);
- Tensor tdVrdO = wg_mma_dKV.partition_fragment_B(sdOt);
- Tensor tdKrQ = wg_mma_dKV.partition_fragment_B(sQt);
- int n_block = get<0>(block_coord);
- int bidh = get<1>(block_coord);
- int bidb = get<2>(block_coord);
- int const seqlen_q = get_seqlen_q(params, bidb);
- int const seqlen_k = get_seqlen_k(params, bidb);
- int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{}));
- int m_block_min = get_m_block_min(params, n_block, bidb);
- int m_block = m_block_min;
- // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the row indices.
- Tensor tLSEsLSE = thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _0{}, _); // (2, V, PIPE)
- Tensor tLSEsdPsum = thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _0{}, _);
- // The dK / dV accumulators start from zero; the dKV GEMMs below accumulate into them (zero_init=false).
- clear(tdKrdK);
- clear(tdVrdV);
- // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero;
- // Wait on barrier_KV with phase work_idx % 2 before sK / sV are read
- // (presumably signalled once K and V have been loaded into smem; confirm against the producer).
- cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_KV.try_wait(work_idx % 2));
- if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_KV.wait(work_idx % 2); }
- // Computes the dQ tile for the current pipeline stage from sdS and sKt, releases the Q stage,
- // copies the result from registers to sdQ and arrives on dQFull. Waits on dQEmpty before overwriting sdQ.
- auto compute_dQ = [&]() {
- static_assert(!Mma_dQ_is_RS);
- // SMEM fence to make sure sdS is written before it's read by WGMMA
- cutlass::arch::fence_view_async_shared();
- cutlass::arch::NamedBarrier::sync(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
- Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
- if constexpr (!dQ_swapAB) {
- Tensor tdQrdS = wg_mma_dQ.partition_fragment_A(sdS);
- Tensor tdQrK = wg_mma_dQ.partition_fragment_B(sKt);
- flash::gemm</*zero_init=*/true, /*wg_wait=*/1>(tiled_mma_dQ, tdQrdS(_, _, _, smem_pipe_read.index()), tdQrK, tdQrdQ);
- } else {
- Tensor tdQrdS = wg_mma_dQ.partition_fragment_B(sdS);
- Tensor tdQrK = wg_mma_dQ.partition_fragment_A(sKt);
- flash::gemm</*zero_init=*/true, /*wg_wait=*/1>(tiled_mma_dQ, tdQrK, tdQrdS(_, _, _, smem_pipe_read.index()), tdQrdQ);
- }
- pipeline_q.consumer_release(smem_pipe_read); // release Q
- warpgroup_wait<0>();
- Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
- cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
- cutlass::arch::fence_view_async_shared();
- cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQFull) /*id*/); // sdQ full, to be written to gmem
- };
- // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
- // this helps quite a bit to not have to do causal masking for most of the iterations.
- if constexpr (Is_causal) {
- static constexpr int n_masking_steps = cute::ceil_div(kBlockN, kBlockM) + 1;
- // The upper bound of the masking iterations does not change inside the loop, so compute it once.
- int const m_block_masking_max = std::min(m_block_max, m_block_min + n_masking_steps);
- CUTLASS_PRAGMA_NO_UNROLL
- for (; m_block < m_block_masking_max; ++m_block) {
- Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
- pipeline_q.consumer_wait(smem_pipe_read);
- flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tSrK, tSrQ(_, _, _, smem_pipe_read.index()), tSrS);
- Tensor tLSErLSE = make_fragment_like(tLSEsLSE(_, _, _0{}));
- cute::copy(tLSEsLSE(_, _, smem_pipe_read.index()), tLSErLSE);
- Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
- pipeline_do.consumer_wait(smem_pipe_read);
- flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read.index()), tdPrdP);
- // Two GEMMs are in flight; allow one to remain outstanding before tSrS is read.
- warpgroup_wait<1>();
- Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{}));
- Tensor taccScS = thread_mma_SdP.partition_C(cS);
- int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
- // Coordinate 0 is the key index within this N tile, coordinate 1 the query index within this M tile.
- // Mask entries whose key is past the causal limit for that query or past the end of the key sequence.
- #pragma unroll
- for (int i = 0; i < size(tSrS); ++i) {
- if (int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + causal_row_offset,
- seqlen_k - n_block * kBlockN)) {
- tSrS(i) = -INFINITY;
- }
- }
- // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
- Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
- flash::scale_apply_exp2</*Scale_max=*/false, /*Check_inf=*/false>(scores, group_modes<0, 2>(tLSErLSE), params.softmax_scale_log2);
- Tensor tLSErdPsum = make_fragment_like(tLSEsdPsum(_, _, _0{}));
- cute::copy(tLSEsdPsum(_, _, smem_pipe_read.index()), tLSErdPsum);
- // Convert scores from fp32 to fp16/bf16
- Tensor rP = flash::convert_type<Element>(tSrS);
- warpgroup_wait<0>();
- // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
- Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
- // Unroll both loops, as the non-causal path below does, so register fragment indices are compile-time constants.
- #pragma unroll
- for (int mi = 0; mi < size<0>(dS); ++mi) {
- #pragma unroll
- for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - tLSErdPsum(mi)); }
- }
- Tensor rdS = flash::convert_type<Element>(tdPrdP);
- // Because of double buffering on dS, we don't need to sync here.
- // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
- // But because both WGs have to sync at the end of the loop and double buffering, this race condition
- // is not possible.
- Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
- cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, smem_pipe_read.index()));
- Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
- flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read.index()), tdVrdV);
- Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
- flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
- pipeline_do.consumer_release(smem_pipe_read); // release dO
- compute_dQ();
- ++smem_pipe_read;
- }
- }
- // Remaining m_blocks (all of them when !Is_causal): only the key-sequence bound is masked.
- CUTLASS_PRAGMA_NO_UNROLL
- for (; m_block < m_block_max; ++m_block) {
- Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
- pipeline_q.consumer_wait(smem_pipe_read);
- flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tSrK, tSrQ(_, _, _, smem_pipe_read.index()), tSrS);
- Tensor tLSErLSE = make_fragment_like(tLSEsLSE(_, _, _0{}));
- cute::copy(tLSEsLSE(_, _, smem_pipe_read.index()), tLSErLSE);
- Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
- pipeline_do.consumer_wait(smem_pipe_read);
- flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read.index()), tdPrdP);
- warpgroup_wait<1>();
- Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{}));
- Tensor taccScS = thread_mma_SdP.partition_C(cS);
- // Mask keys that lie past the end of the key sequence for this N tile.
- #pragma unroll
- for (int i = 0; i < size(tSrS); ++i) {
- if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
- }
- // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
- Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
- // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tLSErLSE); }
- // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(scores); }
- flash::scale_apply_exp2</*Scale_max=*/false, /*Check_inf=*/false>(scores, group_modes<0, 2>(tLSErLSE), params.softmax_scale_log2);
- // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(scores); }
- Tensor tLSErdPsum = make_fragment_like(tLSEsdPsum(_, _, _0{}));
- cute::copy(tLSEsdPsum(_, _, smem_pipe_read.index()), tLSErdPsum);
- // Convert scores from fp32 to fp16/bf16
- Tensor rP = flash::convert_type<Element>(tSrS);
- warpgroup_wait<0>();
- // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
- Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
- #pragma unroll
- for (int mi = 0; mi < size<0>(dS); ++mi) {
- #pragma unroll
- for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - tLSErdPsum(mi)); }
- }
- // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dS); }
- Tensor rdS = flash::convert_type<Element>(tdPrdP);
- Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
- cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, smem_pipe_read.index()));
- Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
- flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read.index()), tdVrdV);
- Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
- flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
- pipeline_do.consumer_release(smem_pipe_read); // release dO
- compute_dQ();
- ++smem_pipe_read;
- }
- // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }
- // Apply params.softmax_scale to the accumulated dK once, after the loop.
- #pragma unroll
- for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }
- }
- };
- } // namespace flash
|