1
0

mainloop_bwd_sm90_tma_gmma_ws.hpp 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include <cutlass/barrier.h>
  10. #include "cutlass/pipeline/pipeline.hpp"
  11. #include "cute/tensor.hpp"
  12. #include "cutlass/gemm/collective/builders/sm90_common.inl"
  13. #include "named_barrier.hpp"
  14. #include "seqlen.h"
  15. #include "mask.h"
  16. #include "softmax.h"
  17. #include "utils.h"
  18. #include "copy_sm90_bulk_reduce.hpp"
  19. namespace flash {
  20. template <bool A, class Mma, class Tensor0>
  21. CUTLASS_DEVICE
  22. auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) {
  23. if constexpr (A) {
  24. return mma.partition_fragment_A(tensor0);
  25. } else {
  26. return mma.partition_fragment_B(tensor0);
  27. }
  28. }
  29. using namespace cute;
  30. template <int Stages, int Stages_dO, int Stages_dS, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
  31. bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool Deterministic,
  32. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  33. int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
  34. struct CollectiveMainloopBwd {
  35. static constexpr int kStages = Stages;
  36. static constexpr int kStages_dO = Stages_dO;
  37. static constexpr int kStages_dS = Stages_dS;
  38. static_assert(kStages >= kStages_dO);
  39. static_assert(Stages_dS == 1 || Stages_dS == kStages);
  40. using ClusterShape = ClusterShape_;
  41. using TileShape_MNK = TileShape_MNK_;
  42. using Element = Element_;
  43. using ElementAccum = ElementAccum_;
  44. using ArchTag = ArchTag_;
  45. static constexpr bool Is_causal = Is_causal_;
  46. static constexpr bool Is_local = Is_local_;
  47. static constexpr bool Has_softcap = Has_softcap_;
  48. static constexpr bool Varlen = Varlen_;
  49. using SeqlenInfo_t = flash::SeqlenInfoQK<Varlen, CUTE_STATIC_V(get<0>(TileShape_MNK{}))>;
  50. static constexpr bool SdP_swapAB = SdP_swapAB_;
  51. static constexpr bool dKV_swapAB = dKV_swapAB_;
  52. static constexpr bool dQ_swapAB = dQ_swapAB_;
  53. static constexpr bool Q_dO_same_stages = kStages == kStages_dO;
  54. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  55. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  56. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  57. static_assert(ArchTag::kMinComputeCapability >= 90);
  58. static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1);
  59. static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
  60. static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0);
  61. static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0);
  62. static_assert(NumMmaWarpGroups % AtomLayoutMdQ == 0);
  63. static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarpGroups && SdP_swapAB && !dKV_swapAB;
  64. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarpGroups && AtomLayoutMdQ == NumMmaWarpGroups && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  65. static constexpr GMMA::Major PdS_Major = GMMA::Major::K;
  66. // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN;
  67. static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ? GMMA::Major::MN : GMMA::Major::K;
  68. using TileShapeAtomSdP = std::conditional_t<
  69. !SdP_swapAB,
  70. Shape<Int<kBlockM>, Int<kBlockN / (NumMmaWarpGroups / AtomLayoutMSdP)>, Int<kHeadDim>>,
  71. Shape<Int<kBlockN>, Int<kBlockM / AtomLayoutMSdP>, Int<kHeadDim>>
  72. >;
  73. using AtomLayoutSdP = std::conditional_t<
  74. !SdP_swapAB,
  75. Layout<Shape<Int<AtomLayoutMSdP>, Int<NumMmaWarpGroups / AtomLayoutMSdP>, _1>>,
  76. Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  77. >;
  78. using TiledMmaSdP = decltype(cute::make_tiled_mma(
  79. cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  80. AtomLayoutSdP{}));
  81. using TiledMmadPRS = decltype(cute::make_tiled_mma(
  82. cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  83. AtomLayoutSdP{}));
  84. using TileShapeAtomdKV = std::conditional_t<
  85. !dKV_swapAB,
  86. Shape<Int<kBlockN>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutNdKV)>, Int<kBlockM>>,
  87. Shape<Int<kHeadDim>, Int<kBlockN / AtomLayoutNdKV>, Int<kBlockM>>
  88. >;
  89. using AtomLayoutdKV = std::conditional_t<
  90. !dKV_swapAB,
  91. Layout<Shape<Int<AtomLayoutNdKV>, Int<NumMmaWarpGroups / AtomLayoutNdKV>, _1>>,
  92. Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  93. >;
  94. using TiledMmadKV = decltype(cute::make_tiled_mma(
  95. std::conditional_t<
  96. Mma_dKV_is_RS,
  97. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>()),
  98. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, !dKV_swapAB ? PdSt_Major : GMMA::Major::MN, !dKV_swapAB ? GMMA::Major::MN : PdSt_Major>())
  99. >{},
  100. AtomLayoutdKV{}));
  101. using TileShapeAtomdQ = std::conditional_t<
  102. !dQ_swapAB,
  103. Shape<Int<kBlockM>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutMdQ)>, Int<kBlockN>>,
  104. Shape<Int<kHeadDim>, Int<kBlockM / AtomLayoutMdQ>, Int<kBlockN>>
  105. >;
  106. using AtomLayoutdQ = std::conditional_t<
  107. !dQ_swapAB,
  108. Layout<Shape<Int<AtomLayoutMdQ>, Int<NumMmaWarpGroups / AtomLayoutMdQ>, _1>>,
  109. Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  110. >;
  111. using TiledMmadQ = decltype(cute::make_tiled_mma(
  112. std::conditional_t<
  113. Mma_dQ_is_RS,
  114. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
  115. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, !dQ_swapAB ? PdS_Major : GMMA::Major::MN, !dQ_swapAB ? GMMA::Major::MN : PdS_Major>())
  116. >{},
  117. AtomLayoutdQ{}));
  118. // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.
  119. // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.
  120. // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension
  121. // changes the layout.
  122. using SmemLayoutAtomQdO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  123. Int<kBlockM>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutNdKV)>>()); // for dKV_Mma
  124. using SmemLayoutQ =
  125. decltype(tile_to_shape(SmemLayoutAtomQdO{},
  126. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  127. using SmemLayoutdO =
  128. decltype(tile_to_shape(SmemLayoutAtomQdO{},
  129. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages_dO>{})));
  130. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  131. Int<kBlockN>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutMdQ)>>());
  132. using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
  133. using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  134. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  135. using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
  136. using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<PdS_Major, Element,
  137. Int<kBlockM / AtomLayoutMSdP>,
  138. Int<kBlockN / (NumMmaWarpGroups / AtomLayoutMSdP)>>());
  139. using SmemLayoutPdS = decltype(tile_to_shape(
  140. SmemLayoutAtomPdS{},
  141. make_shape(Int<kBlockM>{}, Int<kBlockN>{}, Int<kStages_dS>{}),
  142. std::conditional_t<PdS_Major == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));
  143. // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80
  144. // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,
  145. // it's still a valid smem address.
  146. using SmemLayoutLSE = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>, cute::Stride<_1, Int<cute::round_up(kBlockM, 64)>>>;
  147. using SmemLayoutLSEMma = std::conditional_t<
  148. SdP_swapAB,
  149. cute::Layout<cute::Shape<Int<kBlockN>, Int<kBlockM>, Int<kStages>>, cute::Stride<_0, _1, Int<cute::round_up(kBlockM, 64)>>>,
  150. cute::Layout<cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kStages>>, cute::Stride<_1, _0, Int<cute::round_up(kBlockM, 64)>>>
  151. >;
  152. // Note this is the transpose in terms of the view, not in terms of memory.
  153. using SmemLayoutQt =
  154. decltype(cute::composition(SmemLayoutQ{},
  155. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  156. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  157. using SmemLayoutdOt =
  158. decltype(cute::composition(SmemLayoutdO{},
  159. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages_dO>{}),
  160. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  161. using SmemLayoutKt =
  162. decltype(cute::composition(SmemLayoutK{},
  163. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  164. make_stride(Int<kBlockN>{}, _1{}))));
  165. using SmemLayoutPdSt =
  166. decltype(cute::composition(SmemLayoutPdS{},
  167. make_layout(make_shape(Int<kBlockN>{}, Int<kBlockM>{}, Int<kStages_dS>{}),
  168. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kBlockN>{}))));
  169. // Thread layout, 256 or 384 threads per row
  170. // We split into NumMmaWarpGroups so that we can do Bulk reduce add for each WG separately.
  171. using R2SLayoutAtomdQaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumMmaWarpGroups>>>;
  172. using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
  173. Layout<Shape < _4>>{})); // Val layout, 4 vals per store
  174. using SmemLayoutdQaccum = Layout<Shape<Int<kBlockM * kHeadDim / NumMmaWarpGroups>, Int<NumMmaWarpGroups>>>;
  175. static constexpr int kNumPdSStore = kBlockM * kBlockN / NumMmaThreads;
  176. // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt.
  177. // If PdS_major is MN, then we need to "transpose" the write.
  178. using SmemCopyAtomPdS = Copy_Atom<
  179. std::conditional_t<(!SdP_swapAB) ^ (PdS_Major == GMMA::Major::MN),
  180. std::conditional_t<kNumPdSStore % 8 == 0, cute::SM90_U32x4_STSM_N, cute::SM90_U32x2_STSM_N>,
  181. std::conditional_t<kNumPdSStore % 8 == 0, cute::SM90_U16x8_STSM_T, cute::SM90_U16x4_STSM_T>
  182. >,
  183. Element
  184. >;
  185. using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{})));
  186. using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  187. using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
  188. using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  189. using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
  190. using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
  191. using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
  192. using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
  193. using TMA_QdO = decltype(make_tma_copy_A_sm90(
  194. GmemTiledCopyQdO{},
  195. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
  196. take<0, 2>(SmemLayoutQ{}),
  197. TileShape_MNK{},
  198. ClusterShape{})); // mcast along N mode for this M load, if any
  199. using TMA_K = decltype(make_tma_copy_B_sm90(
  200. GmemTiledCopyKV{},
  201. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
  202. SmemLayoutK{},
  203. TileShape_MNK{},
  204. ClusterShape{})); // no mcast for KV
  205. using TMA_V = decltype(make_tma_copy_B_sm90(
  206. GmemTiledCopyKV{},
  207. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
  208. SmemLayoutV{},
  209. TileShape_MNK{},
  210. ClusterShape{})); // no mcast for KV
  211. using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  212. using PipelineState = typename MainloopPipeline::PipelineState;
  213. using MainloopPipeline_dO = typename cutlass::PipelineTmaAsync<kStages_dO>;
  214. using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
  215. // Set the bytes transferred in this TMA transaction (may involve multiple issues)
  216. static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v<Element> / 8);
  217. static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(SmemLayoutK{}) * cutlass::sizeof_bits_v<Element> / 8);
  218. static constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size(SmemLayoutV{}) * cutlass::sizeof_bits_v<Element> / 8);
  219. static constexpr uint32_t TmaTransactionBytesLSE = static_cast<uint32_t>(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v<ElementAccum> / 8);
  220. // These are tuned for speed. They don't affect correctness.
  221. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
  222. // this helps quite a bit to not have to do causal masking for most of the iterations.
  223. // For hdim 192, separating masking iterations results in register spills.
  224. static constexpr bool SeparateMaskingIterations = kHeadDim <= 64;
  225. // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then
  226. // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each
  227. // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep
  228. // statistics for 2 rows.
  229. static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64;
  230. static constexpr bool ShuffledPSum = SdP_swapAB && kHeadDim <= 64;
  231. // If we have extra registers, we can keep V in registers to reduce smem traffic.
  232. static constexpr bool Mma_dP_is_RS = SdP_swapAB && kHeadDim == 96;
  233. static constexpr bool dQacc_use_TMA = kHeadDim < 256;
  234. // For hdim256, we want to slice the dQ MMA (64 x 256 on 2 WGs) into two (64 x 128 on 2 WGs) so that we can
  235. // do atomic add on one half before doing the other half of the MMA, to reduce register pressure.
  236. static constexpr bool Slice_dQKV_Mma = kHeadDim == 256 && !dQacc_use_TMA && dQ_swapAB && AtomLayoutMdQ == 1 && NumMmaWarpGroups == 2;
  237. static_assert(!(Deterministic && Slice_dQKV_Mma), "Deterministic mode not supported with Slice_dQKV_Mma");
  238. static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{});
  239. static constexpr size_t SmemAlignmentdS = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{});
  240. // Without this SmemAlignment, with hdim 256 we get "misaligned address" error in TMA
  241. static constexpr size_t SmemAlignmentQKVdO = kHeadDim % 256 == 0 ? 256 : 128;
  242. static constexpr size_t SmemAlignmentV = !Mma_dP_is_RS ? SmemAlignmentQKVdO : cutlass::detail::alignment_for_swizzle(SmemLayoutV{});
  243. static_assert(SmemAlignmentP >= 128 && SmemAlignmentdS >= 128, "Require at least 128B alignment");
  244. // TODO: do we have to worry that smem_dk and smem_dv in the epilogue don't line up w smem_k and smem_v due to alignment?
  245. using SmemdQacc_t = std::conditional_t<!dQacc_use_TMA, cute::array<ElementAccum, 0>, cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>>>;
  246. using SmemP_t = std::conditional_t<Mma_dKV_is_RS, cute::array<Element, 0>, cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>, SmemAlignmentP>>;
  // Backing storage for every buffer the mainloop stages, in declaration order.
  // The struct itself is aligned to the largest of the P, dS and Q/K/V/dO alignments.
  // load() wraps smem_q, smem_do, smem_k, smem_v, smem_lse and smem_dpsum with make_smem_ptr
  // and the matching SmemLayout* type.
  247. struct TensorStorage : cute::aligned_struct<cute::max(SmemAlignmentP, SmemAlignmentdS, SmemAlignmentQKVdO)> {
  // K tile, sized by the cosize of SmemLayoutK.
  248. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentQKVdO> smem_k;
  // V tile. Uses SmemAlignmentV, which differs from SmemAlignmentQKVdO only when Mma_dP_is_RS.
  249. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>, SmemAlignmentV> smem_v;
  // dQ accumulator staging buffer; a zero-length array when !dQacc_use_TMA (see SmemdQacc_t).
  250. SmemdQacc_t smem_dqacc;
  // Q buffer, sized by SmemLayoutQ (which carries kStages in its last mode).
  251. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQKVdO> smem_q;
  // dO buffer, sized by SmemLayoutdO (which carries kStages_dO in its last mode).
  252. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>, SmemAlignmentQKVdO> smem_do;
  // LSE and dPsum buffers share SmemLayoutLSE and a fixed 128-byte alignment.
  253. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;
  254. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;
  // P buffer; a zero-length array when Mma_dKV_is_RS (see SmemP_t).
  255. SmemP_t smem_p;
  // dS buffer, sized by SmemLayoutPdS (which carries kStages_dS in its last mode).
  256. cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>, SmemAlignmentdS> smem_ds;
  257. };
  258. // Host side kernel arguments
  // Converted into Params by to_underlying_arguments(). This is an aggregate, so the member
  // order below is the order callers must use for positional initialization.
  259. struct Arguments {
  // Q: pointer, shape and stride. ShapeQKV is (seqlen, d, head, batch).
  260. Element const* ptr_Q;
  261. ShapeQKV const shape_Q;
  262. StrideQKV const stride_Q;
  // K: pointer, shape and stride.
  263. Element const* ptr_K;
  264. ShapeQKV const shape_K;
  265. StrideQKV const stride_K;
  // V has no shape of its own: to_underlying_arguments() pairs ptr_V with shape_K.
  266. Element const* ptr_V;
  267. StrideQKV const stride_V;
  // dO has no shape of its own: to_underlying_arguments() pairs ptr_dO with shape_Q.
  268. Element const* ptr_dO;
  269. StrideQKV const stride_dO;
  // dQ accumulator. ShapedQaccum is (seqlen_q * d, head, batch).
  270. ElementAccum* ptr_dQaccum;
  271. ShapedQaccum const shape_dQaccum;
  272. StridedQaccum const stride_dQaccum;
  // LSE. ShapeLSE is (seqlen, head, batch).
  // NOTE(review): the name suggests the values are already in base-2 units -- confirm with the producer.
  273. float const* ptr_LSE_log2;
  274. ShapeLSE const shape_LSE;
  275. StrideLSE const stride_LSE_log2;
  // dPsum is viewed with shape_LSE and its own stride in load().
  276. float const* ptr_dPsum;
  277. StrideLSE const stride_dPsum;
  // Original softmax scale; to_underlying_arguments() derives softmax_scale_log2 from it.
  278. float const softmax_scale;
  // Window and sink parameters; get_m_block_min_max() reads them (via Params) when Is_causal / Is_local.
  279. int const window_size_left, window_size_right, sink_token_length;
  // Softcap value; only read by to_underlying_arguments() when Has_softcap.
  280. float const softcap_val;
  281. int const num_batch;
  // to_underlying_arguments() asserts this is non-null when Deterministic.
  282. int* const dq_semaphore;
  // Optional sequence-length arrays, forwarded to SeqlenInfo_t in load(). Default to nullptr.
  283. int const* const cu_seqlens_q = nullptr;
  284. int const* const cu_seqlens_k = nullptr;
  285. int const* const seqused_q = nullptr;
  286. int const* const seqused_k = nullptr;
  287. };
  288. // Device side kernel params
  // Built by to_underlying_arguments(), which initializes the members positionally in the
  // order declared below; keep that function in sync if this order changes.
  289. struct Params {
  // Shapes copied unchanged from Arguments.
  290. ShapeQKV const shape_Q;
  291. ShapeQKV const shape_K;
  292. ShapedQaccum const shape_dQaccum;
  // dQ accumulator pointer and stride, copied unchanged from Arguments.
  293. ElementAccum* ptr_dQaccum;
  294. StridedQaccum stride_dQaccum;
  // Divisor is ceil_div(heads in shape_Q, heads in shape_K); load() divides bidh by it to get bidh_kv.
  295. cutlass::FastDivmod qhead_per_khead_divmod;
  // TMA copy objects for Q, dO, K and V created on the host from the argument tensors.
  296. TMA_QdO tma_load_Q, tma_load_dO;
  297. TMA_K tma_load_K;
  298. TMA_V tma_load_V;
  // LSE and dPsum pointers, shape and strides, copied unchanged from Arguments.
  299. float const* ptr_LSE_log2;
  300. ShapeLSE const shape_LSE;
  301. StrideLSE const stride_LSE_log2;
  302. float const* ptr_dPsum;
  303. StrideLSE const stride_dPsum;
  // softmax_scale is the original scale. softmax_scale_log2 is softmax_scale * log2(e) without
  // softcap, or the user softcap_val * log2(e) when Has_softcap.
  304. float const softmax_scale, softmax_scale_log2;
  305. int const window_size_left, window_size_right, sink_token_length;
  // Not the user-supplied value: 0 without softcap, otherwise softmax_scale / user softcap_val.
  306. float const softcap_val;
  307. int const num_batch;
  308. int* const dq_semaphore;
  // Optional sequence-length arrays, forwarded to SeqlenInfo_t in load(). Default to nullptr.
  309. int const* const cu_seqlens_q = nullptr;
  310. int const* const cu_seqlens_k = nullptr;
  311. int const* const seqused_q = nullptr;
  312. int const* const seqused_k = nullptr;
  313. };
  314. static Params
  315. to_underlying_arguments(Arguments const& args) {
  316. Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
  317. TMA_QdO tma_load_Q = make_tma_copy_A_sm90(
  318. GmemTiledCopyQdO{},
  319. mQ,
  320. SmemLayoutQ{}(_, _, _0{}),
  321. TileShape_MNK{},
  322. ClusterShape{}); // mcast along N mode for this M load, if any
  323. Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO);
  324. TMA_QdO tma_load_dO = make_tma_copy_A_sm90(
  325. GmemTiledCopyQdO{},
  326. mdO,
  327. SmemLayoutdO{}(_, _, _0{}),
  328. TileShape_MNK{},
  329. ClusterShape{}); // mcast along N mode for this M load, if any
  330. Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
  331. TMA_K tma_load_K = make_tma_copy_B_sm90(
  332. GmemTiledCopyKV{},
  333. mK,
  334. SmemLayoutK{},
  335. TileShape_MNK{},
  336. ClusterShape{}); // no mcast for KV
  337. Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V);
  338. TMA_V tma_load_V = make_tma_copy_B_sm90(
  339. GmemTiledCopyKV{},
  340. mV,
  341. SmemLayoutV{},
  342. TileShape_MNK{},
  343. ClusterShape{}); // no mcast for KV
  344. if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }
  345. // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
  346. // Right after this, we multiply by log2(e) before applying exp2.
  347. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
  348. // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
  349. // (assigning it to params.softmax_scale_log2).
  350. // In the backward, we need to multiply by
  351. // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale.
  352. // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale
  353. // (the original softmax_scale) at the end.
  354. return {args.shape_Q, args.shape_K, args.shape_dQaccum,
  355. args.ptr_dQaccum, args.stride_dQaccum,
  356. cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
  357. tma_load_Q, tma_load_dO, tma_load_K, tma_load_V,
  358. args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
  359. args.softmax_scale,
  360. !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
  361. args.window_size_left, args.window_size_right, args.sink_token_length,
  362. !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
  363. args.num_batch, args.dq_semaphore,
  364. args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k};
  365. }
  366. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  367. CUTLASS_DEVICE
  368. static void prefetch_tma_descriptors(Params const& params) {
  369. cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
  370. cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor());
  371. cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
  372. cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
  373. }
  374. CUTLASS_DEVICE
  375. cute::tuple<int, int> get_m_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info,
  376. int n_block, int bidb) {
  377. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  378. int const seqlen_q = seqlen_info.seqlen_q;
  379. int const seqlen_k = seqlen_info.seqlen_k;
  380. int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
  381. if constexpr (Is_local) {
  382. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  383. if (n_block >= cute::ceil_div(params.sink_token_length, kBlockN)) {
  384. m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM));
  385. }
  386. }
  387. int m_block_min = 0;
  388. if constexpr (Is_causal || Is_local) {
  389. m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM);
  390. }
  391. return {m_block_min, m_block_max};
  392. }
  393. template <typename SchedulerPrefetch, typename SharedStorage>
  394. CUTLASS_DEVICE void
  395. load(Params const& params,
  396. MainloopPipeline pipeline_q,
  397. MainloopPipeline_dO pipeline_do,
  398. PipelineState& smem_pipe_write,
  399. PipelineState_dO& smem_pipe_write_do,
  400. SharedStorage &shared_storage,
  401. SchedulerPrefetch const& scheduler_prefetch,
  402. cute::tuple<int32_t, int32_t, int32_t> block_coord
  403. ) {
  404. auto [n_block, bidh, bidb] = block_coord;
  405. SeqlenInfo_t seqlen_info{
  406. bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
  407. params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
  408. };
  409. auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb);
  410. // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access.
  411. if constexpr (Is_causal || Is_local || Varlen) {
  412. if (m_block_max <= m_block_min) {
  413. scheduler_prefetch();
  414. return;
  415. }
  416. }
  417. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  418. Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{});
  419. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  420. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  421. Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{});
  422. Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{});
  423. int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
  424. // Prepare the TMA loads
  425. uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  426. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  427. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  428. bool const is_varlen_q = Varlen && params.cu_seqlens_q;
  429. bool const is_varlen_k = Varlen && params.cu_seqlens_k;
  430. Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
  431. Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
  432. Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
  433. Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
  434. Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0);
  435. Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0);
  436. Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
  437. Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
  438. Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
  439. Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
  440. Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _)
  441. Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _)
  442. Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{}));
  443. Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{}));
  444. Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{}));
  445. Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{}));
  446. // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout<ClusterShape>{},
  447. // group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE)
  448. // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout<ClusterShape>{},
  449. // group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE)
  450. auto block_tma_Q = params.tma_load_Q.get_slice(cluster_local_block_id.y);
  451. auto block_tma_dO = params.tma_load_dO.get_slice(cluster_local_block_id.y);
  452. Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ));
  453. Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ));
  454. Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO));
  455. Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO));
  456. auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{},
  457. group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA)
  458. auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{},
  459. group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA)
  460. auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
  461. uint16_t mcast_mask_qdo = 0;
  462. if constexpr (cute::is_same_v<GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
  463. auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
  464. for (int n = 0; n < size<1>(block_layout); ++n) {
  465. mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{}));
  466. }
  467. }
  468. int m_block = m_block_min;
  469. int lane_predicate = cute::elect_one_sync();
  470. if (lane_predicate) {
  471. pipeline_q.producer_acquire(smem_pipe_write);
  472. copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),
  473. tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index()));
  474. copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)),
  475. gLSE(_, m_block), sLSE(_, smem_pipe_write.index()));
  476. }
  477. // // Wait for the MMA warpgroups to say that smem_k and smem_v are ready
  478. // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::KVEmpty) /*id*/);
  479. if (lane_predicate) {
  480. // Copy K tile and V tile from GMEM to SMEM.
  481. shared_storage.pipelines.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV);
  482. copy(params.tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK);
  483. copy(params.tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV);
  484. #pragma unroll (kHeadDim < 256 ? 2 : 1)
  485. for (; m_block < m_block_max - 1; ++m_block) {
  486. // If Q and dO have the same number of stages, we can use the same pipeline state variable
  487. // to reduce registers
  488. PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_write, smem_pipe_write_do);
  489. pipeline_do.producer_acquire(smem_pipe_write_do_cur);
  490. copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),
  491. tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index()));
  492. copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)),
  493. gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index()));
  494. if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; }
  495. ++smem_pipe_write;
  496. pipeline_q.producer_acquire(smem_pipe_write);
  497. copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),
  498. tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index()));
  499. copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)),
  500. gLSE(_, m_block + 1), sLSE(_, smem_pipe_write.index()));
  501. }
  502. }
  503. scheduler_prefetch();
  504. if (lane_predicate) {
  505. PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_write, smem_pipe_write_do);
  506. pipeline_do.producer_acquire(smem_pipe_write_do_cur);
  507. copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),
  508. tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index()));
  509. copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)),
  510. gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index()));
  511. if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; }
  512. ++smem_pipe_write;
  513. }
  514. if constexpr (Q_dO_same_stages) { smem_pipe_write_do = smem_pipe_write; }
  515. }
  516. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  517. CUTLASS_DEVICE void
  518. load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do,
  519. PipelineState& smem_pipe_write) {
  520. static_assert(Q_dO_same_stages, "Q and dO must have the same number of stages");
  521. // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write
  522. PipelineState smem_pipe_write_do = smem_pipe_write;
  523. // Issue the epilogue waits
  524. if (cute::elect_one_sync()) {
  525. /* This helps avoid early exit of blocks in Cluster
  526. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  527. * then would just be acquired since the phase was still inverted from make_producer_start_state
  528. */
  529. pipeline_q.producer_tail(smem_pipe_write);
  530. pipeline_do.producer_tail(smem_pipe_write_do);
  531. }
  532. }
  533. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  534. CUTLASS_DEVICE void
  535. load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do,
  536. PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do) {
  537. // Issue the epilogue waits
  538. if (cute::elect_one_sync()) {
  539. /* This helps avoid early exit of blocks in Cluster
  540. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  541. * then would just be acquired since the phase was still inverted from make_producer_start_state
  542. */
  543. pipeline_q.producer_tail(smem_pipe_write);
  544. pipeline_do.producer_tail(smem_pipe_write_do);
  545. }
  546. }
  /// Drains the per-warpgroup dQ partial tiles from shared memory (sdQ) into the global
  /// dQ accumulator with bulk reduce-add copies.
  ///
  /// Returns immediately unless dQacc_use_TMA is set. For every m_block of this work tile and
  /// every MMA warpgroup it: syncs on the dQFull named barrier, lets one elected lane issue the
  /// reduce-add from sdQ(_, warpgroup_idx) to gdQaccum(_, warpgroup_idx, m_block), then, once the
  /// copy has been waited on, arrives on the dQEmpty named barrier of that warpgroup.
  /// With Deterministic, each m_block is additionally bracketed by a wait/arrive on params.dq_semaphore.
  ///
  /// @param params         Kernel parameters (shapes, cu_seqlens / seqused, dQaccum pointer and strides, semaphore).
  /// @param shared_storage Shared storage holding the smem_dqacc buffer.
  /// @param block_coord    (n_block, bidh, bidb) of the work tile.
  547. template <typename SharedStorage>
  548. CUTLASS_DEVICE void
  549. store_dq(Params const& params,
  550. SharedStorage &shared_storage,
  551. cute::tuple<int32_t, int32_t, int32_t> block_coord
  552. ) {
  // Nothing to do here when dQ is not staged through shared memory.
  553. if constexpr (!dQacc_use_TMA) { return; }
  554. auto [n_block, bidh, bidb] = block_coord;
  555. SeqlenInfo_t seqlen_info{
  556. bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
  557. params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
  558. };
  // Same m_block range computation (and same early exit) as in mma(), so both sides agree on the trip count.
  559. auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb);
  560. // It's possible to have m_block_max <= m_block_min. Exit early
  561. if constexpr (Is_causal || Is_local || Varlen) {
  562. if (m_block_max <= m_block_min) { return; }
  563. }
  564. Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});
  // Byte count of one reduce-add: mode 0 of sdQ is the slice written per warpgroup (sdQ(_, warpgroup_idx) below).
  565. static constexpr int dQ_TMA_num_bytes = CUTE_STATIC_V(size<0>(sdQ)) * sizeof(ElementAccum);
  // With varlen (cu_seqlens_q given) the batch coordinate is fixed to 0 and the sequence is
  // located through seqlen_info.offset_q_padded instead.
  566. bool const is_varlen = Varlen && params.cu_seqlens_q;
  567. Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),
  568. params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
  569. Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_)); // (M * K, _)
  570. Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int<kBlockM * kHeadDim / NumMmaWarpGroups>{}); // (M * K / WG, WG, _)
  571. int const num_batch = params.num_batch;
  572. int const num_head = get<2>(params.shape_Q);
  // One semaphore base per (batch, head); only formed when Deterministic.
  573. int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh;
  574. using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
  // A single lane issues the bulk copies and waits on them; the whole warp takes part in the named barriers.
  575. bool const lane_predicate = cute::elect_one_sync();
  576. int m_block = m_block_min;
  577. #pragma unroll 2
  578. for (; m_block < m_block_max; ++m_block) {
  579. if constexpr (Deterministic) {
  // NOTE(review): presumably blocks until the semaphore for this m_block equals n_block,
  // i.e. until the preceding n_blocks have added their contribution -- confirm against GenericBarrier.
  580. Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block);
  581. }
  582. #pragma unroll
  583. for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) {
  584. cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/); // sdQ full, to be written to gmem
  585. if (lane_predicate) {
  586. SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
  587. tma_store_arrive();
  588. }
  589. }
  590. // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int<x>.
  591. for_each(make_int_sequence<NumMmaWarpGroups>{}, [&] (auto warpgroup_idx) {
  // The wait count shrinks with warpgroup_idx, so the copy issued first is waited on first
  // and its sdQ slice is handed back before the later ones.
  592. if (lane_predicate) { tma_store_wait<NumMmaWarpGroups - 1 - CUTE_STATIC_V(warpgroup_idx)>(); }
  593. cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/); // sdQ empty, ready to be written to
  594. });
  595. if constexpr (Deterministic) {
  // Publish that this n_block is done with this m_block.
  596. Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
  597. }
  598. }
  599. if constexpr (Is_local && Deterministic) {
  // With local attention the loop above may stop before the last m_block of the sequence;
  // the semaphores of the remaining m_blocks are still incremented, without storing anything.
  600. constexpr int kBlockM = get<0>(TileShape_MNK{});
  601. int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM);
  602. #pragma unroll 2
  603. for (; m_block < m_block_global_max; ++m_block) {
  604. Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
  605. }
  606. }
  607. }
  608. CUTLASS_DEVICE void
  609. mma_init() {
  610. // Tell producer (warp 0) that smem_k and smem_v are ready
  611. // We're not currently using this bc we're not using persistent scheduler
  612. // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::KVEmpty) /*id*/);
  613. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  614. if constexpr (dQacc_use_TMA) {
  615. if (warp_idx_in_warpgroup == 0) {
  616. cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmptyWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); // sdQ empty, ready to be written to
  617. }
  618. }
  619. }
  /// Consumer (MMA warpgroup) side of the backward mainloop for one (n_block, bidh, bidb) work tile.
  ///
  /// Waits on barrier_KV, then walks m_block over [m_block_min, m_block_max) calling bwd_step.
  /// Each step consumes one Q stage and one dO stage, runs the S and dP GEMMs, turns them into
  /// P and dS in registers, accumulates into tdVrdV and tdKrdK, and emits the dQ tile either through
  /// the sdQ shared-memory buffer (dQacc_use_TMA) or with atomicAdd on gdQaccum.
  ///
  /// @param params            Kernel parameters (shapes, pointers, softmax scales, window sizes, ...).
  /// @param pipeline_q        Pipeline waited on / released for the Q stages (the LSE stage index follows it).
  /// @param pipeline_do       Pipeline waited on / released for the dO stages (the dPsum stage index follows it).
  /// @param smem_pipe_read    Read state of pipeline_q; advanced once per processed m_block.
  /// @param smem_pipe_read_do Read state of pipeline_do; advanced per m_block when !Q_dO_same_stages,
  ///                          otherwise overwritten with smem_pipe_read before returning true.
  /// @param tdKrdK            Register accumulator for dK; cleared here and scaled by params.softmax_scale at the end.
  /// @param tdVrdV            Register accumulator for dV; cleared here.
  /// @param thread_idx        Thread index used for the MMA and copy partitioning.
  /// @param work_idx          Work counter; its parity selects the barrier_KV phase, incremented before returning true.
  /// @param block_coord       (n_block, bidh, bidb) of the work tile.
  /// @param shared_storage    Shared-memory tensors and barriers.
  /// @return false when the tile has no m_block to process (returns before touching the accumulators
  ///         or work_idx), true otherwise.
  620. template <typename SharedStorage, typename FrgTensordKV>
  621. CUTLASS_DEVICE bool
  622. mma(Params const& params,
  623. MainloopPipeline pipeline_q,
  624. MainloopPipeline_dO pipeline_do,
  625. PipelineState& smem_pipe_read,
  626. PipelineState_dO& smem_pipe_read_do,
  627. FrgTensordKV& tdKrdK,
  628. FrgTensordKV& tdVrdV,
  629. int thread_idx,
  630. int &work_idx,
  631. cute::tuple<int32_t, int32_t, int32_t> block_coord,
  632. SharedStorage& shared_storage
  633. ) {
  634. static_assert(is_rmem<FrgTensordKV>::value, "dK and dV tensor must be rmem resident.");
  635. int n_block = get<0>(block_coord);
  636. int bidb = get<2>(block_coord);
  637. SeqlenInfo_t seqlen_info{
  638. bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
  639. params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
  640. };
  641. auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb);
  642. // It's possible to have m_block_max <= m_block_min. Exit early
  643. if constexpr (Is_causal || Is_local || Varlen) {
  644. if (m_block_max <= m_block_min) { return false; }
  645. }
  // Shared-memory tensor views. The "t" variants (sQt, sdOt, sKt, sPt, sdSt) alias the same buffers
  // as their counterparts through a different layout type (presumably the transposed one).
  646. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  647. Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{});
  648. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  649. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  650. Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{});
  651. Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{});
  652. Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{});
  653. Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{});
  654. Tensor sP_pi = cute::as_position_independent_swizzle_tensor(sP);
  655. Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{});
  656. Tensor sPt_pi = cute::as_position_independent_swizzle_tensor(sPt);
  657. Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{});
  658. Tensor sdS_pi = cute::as_position_independent_swizzle_tensor(sdS);
  659. Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{});
  660. Tensor sdSt_pi = cute::as_position_independent_swizzle_tensor(sdSt);
  661. Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});
  662. Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{});
  663. Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{});
  664. static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and
  665. stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and
  666. size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and
  667. size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup,
  668. "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
  // Layouts mapping a warpgroup index to the index of its first thread, used to take per-warpgroup MMA slices.
  669. constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
  670. Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),
  671. make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
  672. Layout warp_group_thread_layout_dq = make_layout(make_shape(Int<NumMmaWarpGroups>{}),
  673. make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
  // Broadcast from lane 0 so the warpgroup index is the same value in every lane of the warp.
  674. int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
  675. TiledMmaSdP tiled_mma_SdP;
  676. using TiledMmadP = std::conditional_t<!Mma_dP_is_RS, TiledMmaSdP, TiledMmadPRS>;
  677. TiledMmadP tiled_mma_dP;
  678. TiledMmadKV tiled_mma_dKV;
  679. TiledMmadQ tiled_mma_dQ;
  680. auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx));
  681. auto wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx));
  682. auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);
  683. auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx));
  684. auto wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx));
  685. auto thread0_mma_SdP = tiled_mma_SdP.get_thread_slice(_0{});
  686. auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP);
  687. auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx);
  688. R2STiledCopydQaccum r2s_tiled_copy_dQaccum;
  689. auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
  690. Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ);
  691. // if (thread_idx == 0) { print(sdQ); printf("\n"); print(tdQsdQaccum); printf("\n"); }
  692. // Allocate "fragments/descriptors"
  693. // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda,
  694. // because some partition_fragment_A/B don't compile.
  695. // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function
  696. Tensor tSrQ = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(wg_mma_SdP, sQ);
  697. Tensor tSrK = mma_partition_fragment_AB</*A=*/SdP_swapAB>(wg_mma_SdP, sK);
  698. Tensor tdPrdO = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(wg_mma_SdP, sdO);
  699. Tensor tdPrV = mma_partition_fragment_AB</*A=*/SdP_swapAB>(wg_mma_dP, sV);
  700. Tensor tdVrdO = mma_partition_fragment_AB</*A=*/dKV_swapAB>(wg_mma_dKV, sdOt);
  701. Tensor tdKrQ = mma_partition_fragment_AB</*A=*/dKV_swapAB>(wg_mma_dKV, sQt);
  702. Tensor tdQrdS = mma_partition_fragment_AB</*A=*/!dQ_swapAB>(wg_mma_dQ, sdS);
  703. Tensor tdQrK = mma_partition_fragment_AB</*A=*/dQ_swapAB>(wg_mma_dQ, sKt);
  704. Tensor tPsP = smem_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sP_pi, sPt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  705. Tensor tdSsdS = smem_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sdS_pi, sdSt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  706. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_PdS); print(sP_pi); printf("\n"); print(sPt_pi); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); }
  707. // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices
  708. // or row indices, depending on whether SdP_swapAB.
  709. Tensor tLSEsLSE = cute::conditional_return<!SdP_swapAB>(
  710. group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE)
  711. group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE)
  712. Tensor tLSEsdPsum = cute::conditional_return<!SdP_swapAB>(
  713. group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)),
  714. group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _)));
  715. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); }
  716. // If we want to split the stats among the 8 threads that share the same rows.
  717. static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tLSEsLSE))::value, 8);
  // Helper: block until the stage referred to by the given read state has been filled.
  718. auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
  719. auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  720. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  721. };
  722. int bidh = get<1>(block_coord);
  723. int const seqlen_q = seqlen_info.seqlen_q;
  724. int const seqlen_k = seqlen_info.seqlen_k;
  725. // For the case where we do atomicAdd directly to gdQaccum instead of using TMA
  726. bool const is_varlen = Varlen && params.cu_seqlens_q;
  727. Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),
  728. params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
  729. Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_)); // (M * K, _)
  730. Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int<kBlockM * kHeadDim / NumMmaWarpGroups>{}); // (M * K / WG, WG, _)
  731. // We can reuse r2s_thr_copy_dQaccum for this partitioning
  732. Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum);
  733. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); }
  734. flash::Mask<kBlockM, kBlockN, false /*PackGQA*/, TiledMmaSdP, SdP_swapAB> mask(
  735. thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length,
  736. params.qhead_per_khead_divmod
  737. );
  738. int m_block = m_block_min;
  // dK and dV are accumulated across all m_blocks of this tile, so they start from zero.
  739. clear(tdKrdK);
  740. clear(tdVrdV);
  741. // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero;
  // Wait for barrier_KV, using the parity of work_idx as the phase to wait on.
  742. cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.pipelines.barrier_KV.try_wait(work_idx % 2));
  743. if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_KV.wait(work_idx % 2); }
  744. if constexpr (Mma_dP_is_RS) {
  // With a register-sourced dP MMA, V is copied from shared memory into the tdPrV fragment once, ahead of the loop.
  745. using SmemCopyAtomV = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
  746. auto smem_tiled_copy_V = make_tiled_copy_A(SmemCopyAtomV{}, tiled_mma_dP);
  747. auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx);
  748. Tensor tdPrV_copy_view = smem_thr_copy_V.retile_D(tdPrV);
  749. Tensor tdPsV_copy_view = smem_thr_copy_V.partition_S(cute::as_position_independent_swizzle_tensor(sV));
  750. cute::copy(smem_tiled_copy_V, tdPsV_copy_view, tdPrV_copy_view);
  751. }
  // One backward step for a single m_block. mask_fn(tSrS, m_block) applies whatever masking the caller selected.
  752. auto bwd_step = [&](int m_block, auto mask_fn) {
  // --- S GEMM on the current Q stage against K.
  753. Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));
  754. consumer_wait(pipeline_q, smem_pipe_read);
  755. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/SdP_swapAB>(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS);
  // --- Load the LSE values of this stage into registers. With ShuffleLSE each thread keeps only
  // kStatsPerThread values and the per-row value is fetched later with __shfl_sync.
  756. Tensor tLSErLSE = cute::conditional_return<!ShuffleLSE>(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));
  757. if constexpr (!ShuffleLSE) {
  758. cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE);
  759. } else {
  760. #pragma unroll
  761. for (int i = 0; i < kStatsPerThread; ++i) {
  762. // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values
  763. tLSErLSE(i) = tLSEsLSE((thread_idx % 32) / 4 + i * 8, smem_pipe_read.index());
  764. }
  765. }
  // --- dP GEMM on the current dO stage against V.
  766. Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));
  767. PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_read, smem_pipe_read_do);
  768. consumer_wait(pipeline_do, smem_pipe_read_do_cur);
  769. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/SdP_swapAB>(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP);
  // NOTE(review): presumably this leaves the dP GEMM in flight and only waits for the S GEMM, whose
  // result (tSrS) is read right below -- confirm against the wg_wait semantics of flash::gemm.
  770. warpgroup_wait<1>();
  771. if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }
  772. // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  773. Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SdP_swapAB>(tSrS.layout()));
  774. // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh
  775. auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }();
  776. mask_fn(tSrS, m_block);
  // --- P = exp2(S * softmax_scale_log2 - LSE), computed in place in `scores` (which aliases tSrS).
  777. #pragma unroll
  778. for (int mi = 0; mi < size<0>(scores); ++mi) {
  779. float const lse_scaled = [&] {
  780. if constexpr (!ShuffleLSE) return tLSErLSE(mi);
  781. else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
  782. }();
  783. #pragma unroll
  784. for (int ni = 0; ni < size<1>(scores); ++ni) {
  785. scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled);
  786. }
  787. }
  // --- Load the dPsum values of the dO stage, same two strategies as for LSE.
  788. Tensor tLSErdPsum = cute::conditional_return<!ShuffledPSum>(make_fragment_like(tLSEsdPsum(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));
  789. if constexpr (!ShuffledPSum) {
  790. cute::copy(tLSEsdPsum(_, smem_pipe_read_do_cur.index()), tLSErdPsum);
  791. } else {
  792. #pragma unroll
  793. for (int i = 0; i < kStatsPerThread; ++i) {
  794. tLSErdPsum(i) = tLSEsdPsum((thread_idx % 32) / 4 + i * 8, smem_pipe_read_do_cur.index());
  795. }
  796. }
  // All outstanding GMMAs must be complete before tdPrdP is read.
  797. warpgroup_wait<0>();
  798. // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  799. Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
  // --- dS = P * (dP - dPsum), times dtanh when softcapping; computed in place in tdPrdP.
  800. #pragma unroll
  801. for (int mi = 0; mi < size<0>(dS); ++mi) {
  802. float const dP_sum_cur = [&] {
  803. if constexpr (!ShuffledPSum) return tLSErdPsum(mi);
  804. else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
  805. }();
  806. #pragma unroll
  807. for (int ni = 0; ni < size<1>(dS); ++ni) {
  808. dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur);
  809. if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); }
  810. }
  811. }
  812. // Convert scores from fp32 to fp16/bf16
  813. Tensor rP = make_tensor_like<Element>(tSrS);
  814. flash::convert_type_out(tSrS, rP);
  815. if constexpr (!Mma_dKV_is_RS) {
  // P only goes through shared memory when the dKV MMA reads its P operand from there.
  816. // Need to sync to make sure P has already been used in the previous iteration before writing new values
  817. if constexpr (kStages_dS == 1) {
  818. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(BwdNamedBarriers::PdS) /*id*/);
  819. }
  820. Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
  821. cute::copy(smem_tiled_copy_PdS, tPaP, tPsP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index())));
  822. }
  823. Tensor rdS = make_tensor_like<Element>(tdPrdP);
  824. flash::convert_type_out(tdPrdP, rdS);
  825. // If there's double buffering on dS, we don't need to sync here.
  826. // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
  827. // But because both WGs have to sync at the end of the loop and double buffering,
  828. // this race condition is not possible.
  829. // This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and
  830. // (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS.
  831. if constexpr (!Mma_dKV_is_RS || (kStages_dS == 1 && Mma_dKV_is_RS)) {
  832. cutlass::arch::fence_view_async_shared();
  833. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(BwdNamedBarriers::PdS) /*id*/);
  834. }
  835. // For hdim 64, It's faster to write to smem_dS first before the dV gemm
  836. Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
  837. cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index())));
  838. if constexpr (!Slice_dQKV_Mma) {
  839. // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure
  // --- dV GEMM: accumulate into tdVrdV from P (registers or sPt) and the current dO stage.
  840. if constexpr (Mma_dKV_is_RS) {
  841. Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
  842. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);
  843. } else {
  844. Tensor tdVrP = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sPt);
  845. Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  846. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);
  847. }
  848. // SMEM fence to make sure sdS is written before it's read by WGMMA
  849. cutlass::arch::fence_view_async_shared();
  850. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(BwdNamedBarriers::PdS) /*id*/);
  // --- dQ GEMM: a fresh (zero-initialised) tile from dS in shared memory and K.
  851. Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
  852. Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  853. flash::gemm</*zero_init=*/true, /*wg_wait=*/1, /*SwapAB=*/dQ_swapAB>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);
  854. pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO
  // --- dK GEMM: accumulate into tdKrdK from dS (registers or sdSt) and the current Q stage.
  855. if constexpr (Mma_dKV_is_RS) {
  856. Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
  857. flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  858. } else {
  859. Tensor tdKrdS = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sdSt);
  860. Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  861. flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  862. }
  // --- Emit the dQ tile: stage it in sdQ and signal the dQFull barrier, or atomicAdd it straight to global memory.
  863. if constexpr (dQacc_use_TMA) {
  864. int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1;
  865. cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmptyWG1) + warp_group_idx /*id*/); // wait until sdQ is empty, ready to be written to
  866. Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ);
  867. cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
  868. cutlass::arch::fence_view_async_shared();
  869. cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQFullWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem
  870. } else {
  871. // We can reuse r2s_thr_copy_dQaccum for this partitioning
  872. Tensor tdQrdQ_atomic = recast<float4>(r2s_thr_copy_dQaccum.retile_S(tdQrdQ));
  873. Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));
  874. static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic)));
  875. #pragma unroll
  876. for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
  877. }
  878. } else { // Slice_dQKV_Mma
  // Sliced variant: each of the dV, dQ and dK GEMMs is issued as two M_slice halves (0 and 1),
  // interleaved with each other and with the atomicAdd of the dQ half that is already available.
  879. static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS));
  880. Tensor tdVrP = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sPt);
  881. Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  882. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/0>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);
  883. cutlass::arch::fence_view_async_shared();
  884. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(BwdNamedBarriers::PdS) /*id*/);
  885. Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
  886. Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  887. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/dQ_swapAB, /*M_slice=*/0>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);
  888. flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/1>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);
  889. Tensor tdQrdQ_atomic = recast<float4>(r2s_thr_copy_dQaccum.retile_S(tdQrdQ));
  890. Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));
  // First half of the dQ fragment goes out while later GEMMs are still running.
  891. #pragma unroll
  892. for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
  893. Tensor tdKrdS = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sdSt);
  894. Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  895. flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/0>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  896. pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO
  897. flash::gemm</*zero_init=*/true, /*wg_wait=*/0, /*SwapAB=*/dQ_swapAB, /*M_slice=*/1>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);
  // Second half of the dQ fragment.
  898. #pragma unroll
  899. for (int i = size(tdQrdQ_atomic) / 2; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
  900. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/1>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  901. }
  // Every GMMA of this step must have completed before the Q stage is handed back.
  902. warpgroup_wait<0>();
  903. pipeline_q.consumer_release(smem_pipe_read); // release Q
  904. ++smem_pipe_read;
  905. if constexpr (!Q_dO_same_stages) { ++smem_pipe_read_do; }
  906. };
  // The m_block range is processed in up to three consecutive loops that share the running m_block:
  // (1) causal/local-masked iterations, (2) iterations without causal/local masking,
  // (3) trailing local-masked iterations. Loops (1) and (3) exist only with SeparateMaskingIterations.
  907. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
  908. // this helps quite a bit to not have to do causal masking for most of the iterations.
  909. if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) {
  910. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
  911. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  912. int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1;
  913. CUTLASS_PRAGMA_NO_UNROLL
  914. for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) {
  915. bwd_step(m_block, mask_fn);
  916. }
  917. }
  918. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  919. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  920. int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations
  921. ? m_block_max
  922. : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM);
  // Without SeparateMaskingIterations this is the only loop and it applies the causal/local mask itself.
  923. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal && !SeparateMaskingIterations, Is_local && !SeparateMaskingIterations>(tSrS, m_block, n_block); };
  924. CUTLASS_PRAGMA_NO_UNROLL
  925. for (; m_block < m_block_max_before_local_mask; ++m_block) {
  926. bwd_step(m_block, mask_fn);
  927. }
  928. if constexpr (Is_local && SeparateMaskingIterations) {
  929. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };
  930. CUTLASS_PRAGMA_NO_UNROLL
  931. for (; m_block < m_block_max; ++m_block) {
  932. bwd_step(m_block, mask_fn);
  933. }
  934. }
  935. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }
  // The softmax scale is applied to dK once, after all m_blocks have been accumulated.
  936. #pragma unroll
  937. for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }
  // With a shared stage count the dO read state was never advanced inside bwd_step; resynchronise it here.
  938. if constexpr (Q_dO_same_stages) { smem_pipe_read_do = smem_pipe_read; }
  939. ++work_idx;
  940. return true;
  941. }
  942. };
  943. } // namespace flash