123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828 |
- /******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
- #pragma once
- #include "cute/algorithm/copy.hpp"
- #include "cute/atom/mma_atom.hpp"
- #include "cutlass/gemm/collective/collective_builder.hpp"
- #include "cutlass/cutlass.h"
- #include "cutlass/layout/layout.h"
- #include "cutlass/numeric_types.h"
- #include "cutlass/pipeline/pipeline.hpp"
- using namespace cute;
- template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
- struct SharedStorageQKVO {
- cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
- cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
- union {
- cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
- cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
- };
- struct {
- cute::uint64_t tma_load_mbar[4]; // 4 TMA barriers pre-allocated for usage.
- cutlass::arch::ClusterTransactionBarrier barrier_Q;
- cutlass::arch::ClusterBarrier barrier_O;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
- int tile_count_semaphore;
- };
- };
- // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
- template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
- int kClusterM_ = 1, typename elem_type=cutlass::half_t>
- struct Flash_fwd_kernel_traits {
- using Element = elem_type;
- using ElementAccum = float;
- using index_t = int64_t;
- using ElementO = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(cutlass::half_t{}, Element{}));
- // The number of threads.
- static constexpr int kNWarps = kNWarps_;
- static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
- static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
- static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
- static constexpr bool Is_WS = kNWarps_ >= 12;
- static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");
- static constexpr int kBlockM = kBlockM_;
- static constexpr int kBlockN = kBlockN_;
- static constexpr int kHeadDim = kHeadDim_;
- static_assert(kHeadDim % 32 == 0);
- using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
- static constexpr int kClusterM = kClusterM_;
- using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
- static constexpr int kStages = kStages_;
- using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
- using TiledMma0 = decltype(cute::make_tiled_mma(
- std::conditional_t<
- Is_Q_in_regs,
- decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
- >{},
- AtomLayoutMNK{}));
- using TiledMma1 = decltype(cute::make_tiled_mma(
- cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
- GMMA::Major::K, cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(
- GMMA::Major::K, GMMA::Major::MN)>(),
- AtomLayoutMNK{}));
- using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
- using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutK =
- decltype(tile_to_shape(SmemLayoutAtomK{},
- make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
- using SmemLayoutAtomVFp16 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutVFp16 =
- decltype(tile_to_shape(SmemLayoutAtomVFp16{},
- make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
- using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
- using SmemLayoutVFp8 =
- decltype(tile_to_shape(SmemLayoutAtomVFp8{},
- make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
- using SmemLayoutV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVFp16{}));
- // Note this is the transpose in terms of the view, not in terms of memory.
- using SmemLayoutVtFp16 =
- decltype(cute::composition(SmemLayoutVFp16{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
- make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutVFp16{}(_, _, _0{}))>{}))));
- using SmemLayoutVt = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVtFp16{}));
- using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementO,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
- using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, ElementO>;
- using SharedStorage = SharedStorageQKVO<kStages, Element, Element, ElementO, SmemLayoutQ,
- SmemLayoutK, SmemLayoutV, SmemLayoutO>;
- using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
- using PipelineState = typename cutlass::PipelineState<kStages>;
- // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
- class SmemLayoutdK, class SmemLayoutdV>
- struct SharedStorageQKVdOdKV;
- template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
- class SmemLayoutdK, class SmemLayoutdV>
- struct SharedStorageQKVdOdKV<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
- union {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
- };
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
- };
- };
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
- };
- struct {
- cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
- cutlass::arch::ClusterTransactionBarrier barrier_K;
- cutlass::arch::ClusterTransactionBarrier barrier_V;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
- };
- };
- template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
- class SmemLayoutdK, class SmemLayoutdV>
- struct SharedStorageQKVdOdKV<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
- union {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
- };
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
- };
- };
- union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
- };
- };
- struct {
- cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
- cutlass::arch::ClusterTransactionBarrier barrier_K;
- cutlass::arch::ClusterTransactionBarrier barrier_V;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
- };
- };
- template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
- class SmemLayoutdK, class SmemLayoutdV>
- struct SharedStorageQKVdOdKVWS;
- template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
- class SmemLayoutdK, class SmemLayoutdV>
- struct SharedStorageQKVdOdKVWS<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
- union {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
- };
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
- };
- };
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
- cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
- cute::array_aligned<float, 128> smem_lse;
- cute::array_aligned<float, 128> smem_dpsum;
- };
- struct {
- cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
- cutlass::arch::ClusterTransactionBarrier barrier_K;
- cutlass::arch::ClusterTransactionBarrier barrier_V;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
- };
- };
- template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
- class SmemLayoutdK, class SmemLayoutdV>
- struct SharedStorageQKVdOdKVWS<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
- union {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
- };
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
- };
- };
- union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
- };
- cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
- cute::array_aligned<float, 128> smem_lse;
- cute::array_aligned<float, 128> smem_dpsum;
- };
- struct {
- cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
- cutlass::arch::ClusterTransactionBarrier barrier_K;
- cutlass::arch::ClusterTransactionBarrier barrier_V;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
- };
- };
- template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
- class SmemLayoutdQ>
- struct SharedStorageQKVdOdKVSeqqPar;
- template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
- class SmemLayoutdQ>
- struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
- union {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
- };
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
- };
- };
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
- };
- struct {
- cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
- cutlass::arch::ClusterTransactionBarrier barrier_Q;
- cutlass::arch::ClusterTransactionBarrier barrier_dO;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
- };
- };
- template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
- class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
- class SmemLayoutdQ>
- struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
- union {
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
- };
- struct {
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
- };
- };
- union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
- };
- };
- struct {
- cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
- cutlass::arch::ClusterTransactionBarrier barrier_Q;
- cutlass::arch::ClusterTransactionBarrier barrier_dO;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
- typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
- };
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
- bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
- int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
- int kClusterN_ = 1, typename elem_type=cutlass::half_t>
- struct Flash_bwd_kernel_traits {
- using Element = elem_type;
- using ElementAccum = float;
- using index_t = int64_t;
- // The number of threads.
- static constexpr int kNWarps = kNWarps_;
- static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
- static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
- // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
- static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;
- static_assert(kNWarps_ == 8 || kNWarps_ == 12);
- static constexpr bool Is_WS = kNWarps_ >= 12;
- static constexpr int kBlockM = kBlockM_;
- static constexpr int kBlockN = kBlockN_;
- static constexpr int kHeadDim = kHeadDim_;
- static_assert(kHeadDim % 32 == 0);
- using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
- static constexpr int kClusterN = kClusterN_;
- using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
- static constexpr int kStages = 2;
- static constexpr bool SdP_swapAB = SdP_swapAB_;
- static constexpr bool dKV_swapAB = dKV_swapAB_;
- static constexpr bool dQ_swapAB = dQ_swapAB_;
- static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
- static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
- using TileShapeAtomSdP = std::conditional_t<
- !SdP_swapAB,
- Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
- Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
- >;
- using AtomLayoutSdP = std::conditional_t<
- !SdP_swapAB,
- Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
- Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
- >;
- using TiledMmaSdP = decltype(cute::make_tiled_mma(
- cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
- AtomLayoutSdP{}));
- using TileShapeAtomdKV = std::conditional_t<
- !dKV_swapAB,
- Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
- Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
- >;
- using AtomLayoutdKV = std::conditional_t<
- !dKV_swapAB,
- Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
- Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
- >;
- using TiledMmadKV = decltype(cute::make_tiled_mma(
- std::conditional_t<
- !SdP_swapAB,
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
- decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
- >{},
- AtomLayoutdKV{}));
- using TileShapeAtomdQ = std::conditional_t<
- !dQ_swapAB,
- Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
- Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
- // Shape<Int<kBlockM>, Int<kHeadDim >, Int<kBlockN>>,
- // Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
- >;
- using AtomLayoutdQ = std::conditional_t<
- !dQ_swapAB,
- Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
- Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
- // Layout<Shape<Int<1>, Int<1>, _1>>,
- // Layout<Shape<Int<1>, Int<1>, _1>>
- >;
- static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
- static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
- using TiledMmadQ = decltype(cute::make_tiled_mma(
- std::conditional_t<
- !dQ_swapAB,
- std::conditional_t<
- Mma_dQ_is_RS,
- decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
- >,
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
- >{},
- AtomLayoutdQ{}));
- using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
- using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
- using GmemTiledCopydKV = cute::SM90_TMA_STORE;
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- static constexpr bool Has_cp_async = true;
- #else
- static constexpr bool Has_cp_async = false;
- #endif
- // For the dot_do_o preprocessing kernel
- using Gmem_copy_struct = std::conditional_t<
- Has_cp_async,
- SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
- DefaultCopy
- >;
- static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
- static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
- // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
- // to affect speed in practice.
- static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
- static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
- using GmemLayoutAtom = Layout<Shape <Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
- Stride<Int<kGmemThreadsPerRow>, _1>>;
- using GmemLayoutAtomdQ = Layout<Shape <Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
- Stride<Int<kGmemThreadsPerRow>, _1>>;
- using GmemTiledCopydO = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
- GmemLayoutAtom{},
- Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
- using GmemTiledCopydQ = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
- GmemLayoutAtomdQ{},
- Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
- using GmemLayoutAtomdQaccum = std::conditional_t<
- kBlockKSmem == 32,
- Layout<Shape <Int<kNThreadsdQ / 8>, _8>, // Thread layout, 8 threads per row
- Stride< _8, _1>>,
- Layout<Shape <Int<kNThreadsdQ / 16>, _16>, // Thread layout, 16 threads per row
- Stride< _16, _1>>
- >;
- using GmemTiledCopydQaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
- GmemLayoutAtomdQaccum{},
- Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
- using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutQ =
- decltype(tile_to_shape(SmemLayoutAtomQ{},
- make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
- using SmemLayoutdO = SmemLayoutQ;
- using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
- using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
- using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
- using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
- using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
- using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
- // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
- // decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
- // Note this is the transpose in terms of the view, not in terms of memory.
- using SmemLayoutQt =
- decltype(cute::composition(SmemLayoutQ{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
- make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
- using SmemLayoutdOt =
- decltype(cute::composition(SmemLayoutdO{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
- make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
- using SmemLayoutKt =
- decltype(cute::composition(SmemLayoutK{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
- make_stride(Int<kBlockN>{}, _1{}))));
- using SmemLayoutPt =
- decltype(cute::composition(SmemLayoutP{},
- make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- using SmemLayoutdSt =
- decltype(cute::composition(SmemLayoutdS{},
- make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- // using SmemLayoutdQacct =
- // decltype(cute::composition(SmemLayoutdQacc{},
- // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- // make_stride(Int<kBlockM>{}, _1{}))));
- using SmemLayoutdK = SmemLayoutK;
- using SmemLayoutdV = SmemLayoutV;
- using SmemLayoutdKt = SmemLayoutKt;
- using SmemLayoutdVt = SmemLayoutKt;
- static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
- using SmemLayoutAtomdQ = decltype(
- // composition(Swizzle<kSwizzle, 3, 3>{},
- composition(Swizzle<3, 3, 3>{},
- Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
- Stride<Int<32>, _1>>{}));
- using SmemLayoutdQ = decltype(tile_to_shape(
- SmemLayoutAtomdQ{},
- make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
- using SmemLayoutdQt =
- decltype(cute::composition(SmemLayoutdQ{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
- using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
- using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
- using SmemLayoutdQacc = SmemLayoutdQ;
- using SmemLayoutdQacct = SmemLayoutdQt;
- using SmemLayoutdQacc2 = decltype(tile_to_shape(
- SmemLayoutAtomdQ{},
- make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
- // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
- // using SmemLayoutdQacct =
- // decltype(cute::composition(SmemLayoutdQacc{},
- // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- // make_stride(Int<kBlockM>{}, _1{}))));
- using RmemTiledCopydQacc = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
- GmemLayoutAtomdQaccum{},
- Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
- // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
- using SmemCopyAtomPdS = Copy_Atom<
- std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- using SmemCopyAtomdKV = Copy_Atom<
- std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- using SmemCopyAtomdQ = Copy_Atom<
- std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- using SharedStorage = std::conditional_t<
- !Is_WS,
- SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
- SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
- // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
- >;
- // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
- // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
- using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
- bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
- int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
- int kClusterN_ = 1, typename elem_type=cutlass::half_t>
- struct Flash_bwd_seqqpar_kernel_traits {
- using Element = elem_type;
- using ElementAccum = float;
- using index_t = int64_t;
- // The number of threads.
- static constexpr int kNWarps = kNWarps_;
- static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
- static_assert(kNWarps_ == 8);
- static constexpr int kBlockM = kBlockM_;
- static constexpr int kBlockN = kBlockN_;
- static constexpr int kHeadDim = kHeadDim_;
- static_assert(kHeadDim % 32 == 0);
- using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
- static constexpr int kClusterN = kClusterN_;
- using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
- static constexpr int kStages = 2;
- static constexpr bool SdP_swapAB = SdP_swapAB_;
- static constexpr bool dKV_swapAB = dKV_swapAB_;
- static constexpr bool dQ_swapAB = dQ_swapAB_;
- static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
- static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
- using TileShapeAtomSdP = std::conditional_t<
- !SdP_swapAB,
- Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
- Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
- >;
- using AtomLayoutSdP = std::conditional_t<
- !SdP_swapAB,
- Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
- Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
- >;
- using TiledMmaSdP = decltype(cute::make_tiled_mma(
- cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
- AtomLayoutSdP{}));
- using TileShapeAtomdKV = std::conditional_t<
- !dKV_swapAB,
- Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
- Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
- >;
- using AtomLayoutdKV = std::conditional_t<
- !dKV_swapAB,
- Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
- Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
- >;
- using TiledMmadKV = decltype(cute::make_tiled_mma(
- std::conditional_t<
- !SdP_swapAB,
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
- decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
- >{},
- AtomLayoutdKV{}));
- using TileShapeAtomdQ = std::conditional_t<
- !dQ_swapAB,
- Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
- Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
- >;
- using AtomLayoutdQ = std::conditional_t<
- !dQ_swapAB,
- Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
- Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
- >;
- static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
- static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
- using TiledMmadQ = decltype(cute::make_tiled_mma(
- std::conditional_t<
- !dQ_swapAB,
- std::conditional_t<
- Mma_dQ_is_RS,
- decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
- >,
- decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
- >{},
- AtomLayoutdQ{}));
- using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
- using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
- using GmemTiledCopydKV = cute::SM90_TMA_STORE;
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- static constexpr bool Has_cp_async = true;
- #else
- static constexpr bool Has_cp_async = false;
- #endif
- // For the dot_do_o preprocessing kernel
- using Gmem_copy_struct = std::conditional_t<
- Has_cp_async,
- SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
- DefaultCopy
- >;
- static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
- static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
- // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
- // to affect speed in practice.
- static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
- static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
- using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
- Stride<Int<kGmemThreadsPerRow>, _1>>;
- using GmemTiledCopydO = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
- GmemLayoutAtom{},
- Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
- using GmemTiledCopydQ = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
- GmemLayoutAtom{},
- Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
- using GmemLayoutAtomdQaccum = std::conditional_t<
- kBlockKSmem == 32,
- Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
- Stride< _8, _1>>,
- Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
- Stride< _16, _1>>
- >;
- using GmemTiledCopydQaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
- GmemLayoutAtomdQaccum{},
- Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
- using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
- using SmemLayoutdO = SmemLayoutQ;
- using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
- make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
- using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
- using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
- make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
- using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
- using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
- using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
- decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
- using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
- // Note this is the transpose in terms of the view, not in terms of memory.
- using SmemLayoutQt =
- decltype(cute::composition(SmemLayoutQ{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- using SmemLayoutdOt =
- decltype(cute::composition(SmemLayoutdO{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- using SmemLayoutKt =
- decltype(cute::composition(SmemLayoutK{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
- make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
- using SmemLayoutPt =
- decltype(cute::composition(SmemLayoutP{},
- make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- using SmemLayoutdSt =
- decltype(cute::composition(SmemLayoutdS{},
- make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
- using SmemLayoutdV = SmemLayoutdK;
- using SmemLayoutdKt = SmemLayoutKt;
- using SmemLayoutdVt = SmemLayoutKt;
- using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));
- static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
- using SmemLayoutAtomdQ = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{},
- Layout<Shape<_8, Int<kBlockKSmem>>,
- Stride<Int<kBlockKSmem>, _1>>{}));
- using SmemLayoutdQ = decltype(tile_to_shape(
- SmemLayoutAtomdQ{},
- make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
- using SmemLayoutdQt =
- decltype(cute::composition(SmemLayoutdQ{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
- make_stride(Int<kBlockM>{}, _1{}))));
- static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
- using SmemLayoutAtomdKV = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{},
- Layout<Shape<_8, Int<kBlockKSmem>>,
- Stride<Int<kBlockKSmem>, _1>>{}));
- using SmemLayoutdKV = decltype(tile_to_shape(
- SmemLayoutAtomdKV{},
- make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
- using SmemLayoutdKVt =
- decltype(cute::composition(SmemLayoutdKV{},
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
- make_stride(Int<kBlockN>{}, _1{}))));
- static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;
- // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
- using SmemCopyAtomPdS = Copy_Atom<
- std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- using SmemCopyAtomdKV = Copy_Atom<
- std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- using SmemCopyAtomdQ = Copy_Atom<
- std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
- Element>;
- using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
- SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;
- // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
- // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
- using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
- };
- ////////////////////////////////////////////////////////////////////////////////////////////////////
|