mainloop_bwd_sm90_tma_gmma_ws.hpp 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include <cutlass/barrier.h>
  10. #include "cutlass/pipeline/pipeline.hpp"
  11. #include "cute/tensor.hpp"
  12. #include "cutlass/gemm/collective/builders/sm90_common.inl"
  13. #include "named_barrier.hpp"
  14. #include "seqlen.h"
  15. #include "mask.h"
  16. #include "softmax.h"
  17. #include "utils.h"
  18. #include "copy_sm90_bulk_reduce.hpp"
  19. namespace flash {
  20. using namespace cute;
  21. template <int Stages, int Stages_dO, int Stages_dS, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
  22. bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool Deterministic,
  23. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  24. int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  25. bool Mma_dP_is_RS=false>
  26. struct CollectiveMainloopBwdSm90 {
  // Pipeline depths. kStages sizes the Q pipeline (SmemLayoutQ, MainloopPipeline) and the LSE / dPsum
  // smem buffers (SmemLayoutLSE); kStages_dO sizes the dO pipeline (SmemLayoutdO, MainloopPipeline_dO);
  // kStages_dS is the number of P / dS smem buffers (SmemLayoutPdS).
  27. static constexpr int kStages = Stages;
  28. static constexpr int kStages_dO = Stages_dO;
  29. static constexpr int kStages_dS = Stages_dS;
  // dO may be buffered less deeply than Q, but never more deeply.
  30. static_assert(kStages >= kStages_dO);
  // P / dS are either single-buffered or buffered exactly as deeply as Q.
  31. static_assert(Stages_dS == 1 || Stages_dS == kStages);
  32. static_assert(!Mma_dP_is_RS || SdP_swapAB_); // If Mma_dP_is_RS, we need SdP_SwapAB
  // Re-export the template parameters under un-suffixed names for use below and by callers.
  33. using ClusterShape = ClusterShape_;
  34. using TileShape_MNK = TileShape_MNK_;
  35. using Element = Element_;
  36. using ElementAccum = ElementAccum_;
  37. using ArchTag = ArchTag_;
  38. static constexpr bool Is_causal = Is_causal_;
  39. static constexpr bool Is_local = Is_local_;
  40. static constexpr bool Has_softcap = Has_softcap_;
  41. static constexpr bool Varlen = Varlen_;
  // Per-batch sequence info: load() reads seqlen_q / seqlen_k and offset_q / offset_k / offset_q_padded from it.
  42. using SeqlenInfo_t = flash::SeqlenInfoQK<Varlen, CUTE_STATIC_V(get<0>(TileShape_MNK{}))>;
  43. static constexpr bool SdP_swapAB = SdP_swapAB_;
  44. static constexpr bool dKV_swapAB = dKV_swapAB_;
  45. static constexpr bool dQ_swapAB = dQ_swapAB_;
  // When Q and dO have the same stage count, load() reuses the Q pipeline state for dO to save registers.
  46. static constexpr bool Q_dO_same_stages = kStages == kStages_dO;
  // Tile sizes: kBlockM rows of Q / dO per tile, kBlockN rows of K / V per tile, and the head dimension.
  47. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  48. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  49. static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  // Requires SM90 or newer; the cluster may only be larger than 1 along mode 1 (the N mode).
  50. static_assert(ArchTag::kMinComputeCapability >= 90);
  51. static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1);
  // Thread counts: NumMmaWarpGroups full warpgroups for the MMAs; NumProducerThreads is two warps' worth.
  52. static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
  53. static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp * 2;
  // Every atom layout must split the MMA warpgroups evenly.
  54. static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0);
  55. static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0);
  56. static_assert(NumMmaWarpGroups % AtomLayoutMdQ == 0);
  // Whether the dKV / dQ MMAs use the register-sourced (rs) GMMA variant for their A operand instead of
  // shared memory (see TiledMmadKV / TiledMmadQ). Only enabled for the listed atom-layout / swap
  // combinations -- presumably those where the SdP accumulator fragments already match the A-operand layout.
  57. static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarpGroups && SdP_swapAB && !dKV_swapAB;
  58. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarpGroups && AtomLayoutMdQ == NumMmaWarpGroups && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  // Majorness of the P / dS smem tiles and of their transposed views (the MN alternative is kept commented out).
  59. static constexpr GMMA::Major PdS_Major = GMMA::Major::K;
  60. // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN;
  61. static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ? GMMA::Major::MN : GMMA::Major::K;
  // S / dP MMA, contracted over kHeadDim. Without SdP_swapAB each atom covers kBlockM x (kBlockN split across
  // NumMmaWarpGroups / AtomLayoutMSdP warpgroups); with SdP_swapAB the M / N roles swap, so each atom covers
  // kBlockN x (kBlockM / AtomLayoutMSdP). AtomLayoutSdP arranges the warpgroups to match.
  62. using TileShapeAtomSdP = std::conditional_t<
  63. !SdP_swapAB,
  64. Shape<Int<kBlockM>, Int<kBlockN / (NumMmaWarpGroups / AtomLayoutMSdP)>, Int<kHeadDim>>,
  65. Shape<Int<kBlockN>, Int<kBlockM / AtomLayoutMSdP>, Int<kHeadDim>>
  66. >;
  67. using AtomLayoutSdP = std::conditional_t<
  68. !SdP_swapAB,
  69. Layout<Shape<Int<AtomLayoutMSdP>, Int<NumMmaWarpGroups / AtomLayoutMSdP>, _1>>,
  70. Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  71. >;
  // Both operands from shared memory (ss).
  72. using TiledMmaSdP = decltype(cute::make_tiled_mma(
  73. cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  74. AtomLayoutSdP{}));
  // Same atom shape / layout as TiledMmaSdP, but with the A operand in registers (rs);
  // presumably used for dP when Mma_dP_is_RS -- confirm at the use site.
  75. using TiledMmadPRS = decltype(cute::make_tiled_mma(
  76. cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  77. AtomLayoutSdP{}));
  // dK / dV MMA, contracted over kBlockM. Without dKV_swapAB each atom covers kBlockN x (kHeadDim split across
  // NumMmaWarpGroups / AtomLayoutNdKV warpgroups); with dKV_swapAB it covers kHeadDim x (kBlockN / AtomLayoutNdKV).
  78. using TileShapeAtomdKV = std::conditional_t<
  79. !dKV_swapAB,
  80. Shape<Int<kBlockN>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutNdKV)>, Int<kBlockM>>,
  81. Shape<Int<kHeadDim>, Int<kBlockN / AtomLayoutNdKV>, Int<kBlockM>>
  82. >;
  83. using AtomLayoutdKV = std::conditional_t<
  84. !dKV_swapAB,
  85. Layout<Shape<Int<AtomLayoutNdKV>, Int<NumMmaWarpGroups / AtomLayoutNdKV>, _1>>,
  86. Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  87. >;
  // rs variant when Mma_dKV_is_RS; otherwise ss, with the transposed P / dS (PdSt_Major) on the A or B side
  // depending on dKV_swapAB.
  88. using TiledMmadKV = decltype(cute::make_tiled_mma(
  89. std::conditional_t<
  90. Mma_dKV_is_RS,
  91. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>()),
  92. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, !dKV_swapAB ? PdSt_Major : GMMA::Major::MN, !dKV_swapAB ? GMMA::Major::MN : PdSt_Major>())
  93. >{},
  94. AtomLayoutdKV{}));
  // dQ MMA, contracted over kBlockN. Without dQ_swapAB each atom covers kBlockM x (kHeadDim split across
  // NumMmaWarpGroups / AtomLayoutMdQ warpgroups); with dQ_swapAB it covers kHeadDim x (kBlockM / AtomLayoutMdQ).
  95. using TileShapeAtomdQ = std::conditional_t<
  96. !dQ_swapAB,
  97. Shape<Int<kBlockM>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutMdQ)>, Int<kBlockN>>,
  98. Shape<Int<kHeadDim>, Int<kBlockM / AtomLayoutMdQ>, Int<kBlockN>>
  99. >;
  100. using AtomLayoutdQ = std::conditional_t<
  101. !dQ_swapAB,
  102. Layout<Shape<Int<AtomLayoutMdQ>, Int<NumMmaWarpGroups / AtomLayoutMdQ>, _1>>,
  103. Layout<Shape<Int<NumMmaWarpGroups / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  104. >;
  // rs variant when Mma_dQ_is_RS; otherwise ss, with the untransposed dS (PdS_Major) on the A or B side
  // depending on dQ_swapAB.
  105. using TiledMmadQ = decltype(cute::make_tiled_mma(
  106. std::conditional_t<
  107. Mma_dQ_is_RS,
  108. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
  109. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, !dQ_swapAB ? PdS_Major : GMMA::Major::MN, !dQ_swapAB ? GMMA::Major::MN : PdS_Major>())
  110. >{},
  111. AtomLayoutdQ{}));
  112. // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.
  113. // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.
  114. // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension
  115. // changes the layout.
  116. using SmemLayoutAtomQdO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  117. Int<kBlockM>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutNdKV)>>()); // for dKV_Mma
  // Q / dO smem: (kBlockM, kHeadDim, stages) with kStages and kStages_dO stages respectively.
  118. using SmemLayoutQ =
  119. decltype(tile_to_shape(SmemLayoutAtomQdO{},
  120. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  121. using SmemLayoutdO =
  122. decltype(tile_to_shape(SmemLayoutAtomQdO{},
  123. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages_dO>{})));
  // K smem: (kBlockN, kHeadDim), single buffer. The atom's K-extent matches the per-warpgroup
  // kHeadDim slice of TileShapeAtomdQ.
  124. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  125. Int<kBlockN>, Int<kHeadDim / (NumMmaWarpGroups / AtomLayoutMdQ)>>());
  126. using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
  // V smem: (kBlockN, kHeadDim), single buffer, atom chosen for the full tile.
  127. using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  128. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  129. using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
  // P / dS smem: (kBlockM, kBlockN, kStages_dS). The Step order tiles the atoms M-first or N-first to
  // match PdS_Major.
  130. using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<PdS_Major, Element,
  131. Int<kBlockM / AtomLayoutMSdP>,
  132. Int<kBlockN / (NumMmaWarpGroups / AtomLayoutMSdP)>>());
  133. using SmemLayoutPdS = decltype(tile_to_shape(
  134. SmemLayoutAtomPdS{},
  135. make_shape(Int<kBlockM>{}, Int<kBlockN>{}, Int<kStages_dS>{}),
  136. std::conditional_t<PdS_Major == GMMA::Major::K, cute::Step<_1, _2, _3>, cute::Step<_2, _1, _3>>{}));
  137. // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80
  138. // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,
  139. // it's still a valid smem address.
  // LSE / dPsum smem: kBlockM contiguous values per stage, kStages stages.
  140. using SmemLayoutLSE = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>, cute::Stride<_1, Int<cute::round_up(kBlockM, 64)>>>;
  // The same LSE / dPsum storage viewed as an S-shaped tile (transposed when SdP_swapAB); the zero stride
  // repeats each row's value across the kBlockN mode.
  141. using SmemLayoutLSEMma = std::conditional_t<
  142. SdP_swapAB,
  143. cute::Layout<cute::Shape<Int<kBlockN>, Int<kBlockM>, Int<kStages>>, cute::Stride<_0, _1, Int<cute::round_up(kBlockM, 64)>>>,
  144. cute::Layout<cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kStages>>, cute::Stride<_1, _0, Int<cute::round_up(kBlockM, 64)>>>
  145. >;
  146. // Note this is the transpose in terms of the view, not in terms of memory.
  147. using SmemLayoutQt =
  148. decltype(cute::composition(SmemLayoutQ{},
  149. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  150. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  151. using SmemLayoutdOt =
  152. decltype(cute::composition(SmemLayoutdO{},
  153. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages_dO>{}),
  154. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  155. using SmemLayoutKt =
  156. decltype(cute::composition(SmemLayoutK{},
  157. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  158. make_stride(Int<kBlockN>{}, _1{}))));
  159. using SmemLayoutPdSt =
  160. decltype(cute::composition(SmemLayoutPdS{},
  161. make_layout(make_shape(Int<kBlockN>{}, Int<kBlockM>{}, Int<kStages_dS>{}),
  162. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kBlockN>{}))));
  163. // Thread layout, 256 or 384 threads per row
  164. // We split into NumMmaWarpGroups so that we can do Bulk reduce add for each WG separately.
  165. using R2SLayoutAtomdQaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumMmaWarpGroups>>>;
  166. using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
  167. Layout<Shape < _4>>{})); // Val layout, 4 vals per store
  // dQ accumulator smem: one contiguous slice of kBlockM * kHeadDim / NumMmaWarpGroups ElementAccum values
  // per MMA warpgroup.
  168. using SmemLayoutdQaccum = Layout<Shape<Int<kBlockM * kHeadDim / NumMmaWarpGroups>, Int<NumMmaWarpGroups>>>;
  // Number of P / dS elements each MMA thread holds (and stores to smem) per kBlockM x kBlockN tile.
  169. static constexpr int kNumPdSStore = kBlockM * kBlockN / NumMmaThreads;
  170. // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt.
  171. // If PdS_major is MN, then we need to "transpose" the write.
  // Non-transposed U32 STSM when the register layout already matches the smem majorness, transposed U16 STSM
  // otherwise; the wider (x4 / x8) variants are used when kNumPdSStore is a multiple of 8.
  172. using SmemCopyAtomPdS = Copy_Atom<
  173. std::conditional_t<(!SdP_swapAB) ^ (PdS_Major == GMMA::Major::MN),
  174. std::conditional_t<kNumPdSStore % 8 == 0, cute::SM90_U32x4_STSM_N, cute::SM90_U32x2_STSM_N>,
  175. std::conditional_t<kNumPdSStore % 8 == 0, cute::SM90_U16x8_STSM_T, cute::SM90_U16x4_STSM_T>
  176. >,
  177. Element
  178. >;
  // Q / dO TMA atom depends on the cluster size along N; load() builds a multicast mask only when this is
  // SM90_TMA_LOAD_MULTICAST. K / V always use plain, non-multicast TMA loads.
  179. using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{})));
  180. using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  // Global-memory shapes / strides; the head-dim (QKV) and seqlen (LSE, dQaccum) modes are unit-stride.
  181. using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
  182. using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  183. using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
  184. using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
  185. using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
  186. using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
  // TMA copy types only: the null-pointer tensors fix the template types, while the real descriptors are
  // built from the user's pointers in to_underlying_arguments().
  187. using TMA_QdO = decltype(make_tma_copy_A_sm90(
  188. GmemTiledCopyQdO{},
  189. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
  190. take<0, 2>(SmemLayoutQ{}),
  191. TileShape_MNK{},
  192. ClusterShape{})); // mcast along N mode for this M load, if any
  193. using TMA_K = decltype(make_tma_copy_B_sm90(
  194. GmemTiledCopyKV{},
  195. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
  196. SmemLayoutK{},
  197. TileShape_MNK{},
  198. ClusterShape{})); // no mcast for KV
  199. using TMA_V = decltype(make_tma_copy_B_sm90(
  200. GmemTiledCopyKV{},
  201. make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
  202. SmemLayoutV{},
  203. TileShape_MNK{},
  204. ClusterShape{})); // no mcast for KV
  // TMA pipelines for Q (whose barrier also covers the LSE bulk copy in load()) and for dO.
  205. using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  206. using PipelineState = typename MainloopPipeline::PipelineState;
  207. using MainloopPipeline_dO = typename cutlass::PipelineTmaAsync<kStages_dO>;
  208. using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
  209. // Set the bytes transferred in this TMA transaction (may involve multiple issues)
  // Q: one stage of SmemLayoutQ. K / V: the whole single-buffered tile (load() expects K + V bytes on
  // barrier_KV). LSE: one stage (kBlockM values) of ElementAccum.
  210. static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v<Element> / 8);
  211. static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(SmemLayoutK{}) * cutlass::sizeof_bits_v<Element> / 8);
  212. static constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size(SmemLayoutV{}) * cutlass::sizeof_bits_v<Element> / 8);
  213. static constexpr uint32_t TmaTransactionBytesLSE = static_cast<uint32_t>(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v<ElementAccum> / 8);
  214. // These are tuned for speed. They don't affect correctness.
  215. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
  216. // this helps quite a bit to not have to do causal masking for most of the iterations.
  217. // For hdim 192, separating masking iterations results in register spills.
  218. static constexpr bool SeparateMaskingIterations = kHeadDim <= 64;
  219. // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then
  220. // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each
  221. // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep
  222. // statistic for 2 rows.
  223. static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64;
  224. static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64;
  // The dQ accumulator gets a shared-memory staging buffer (smem_dqacc, see SmemdQacc_t) only for hdim < 256.
  225. static constexpr bool dQacc_use_TMA = kHeadDim < 256;
  226. // For hdim256, we want to slice the dQ MMA (64 x 256 on 2 WGs) into two (64 x 128 on 2 WGs) so that we can
  227. // do atomic add on one half before doing the other half of the MMA, to reduce register pressure.
  228. static constexpr bool Slice_dQKV_Mma = kHeadDim == 256 && !dQacc_use_TMA && dQ_swapAB && AtomLayoutMdQ == 1 && NumMmaWarpGroups == 2;
  229. static_assert(!(Deterministic && Slice_dQKV_Mma), "Deterministic mode not supported with Slice_dQKV_Mma");
  // P / dS buffers are aligned to what their swizzle pattern requires.
  230. static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{});
  231. static constexpr size_t SmemAlignmentdS = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{});
  232. // Without this SmemAlignment, with hdim 256 we get "misaligned address" error in TMA
  233. static constexpr size_t SmemAlignmentQKVdO = kHeadDim % 256 == 0 ? 256 : 128;
  // V follows Q / K / dO unless Mma_dP_is_RS, in which case it uses its own swizzle alignment.
  234. static constexpr size_t SmemAlignmentV = !Mma_dP_is_RS ? SmemAlignmentQKVdO : cutlass::detail::alignment_for_swizzle(SmemLayoutV{});
  235. static_assert(SmemAlignmentP >= 128 && SmemAlignmentdS >= 128, "Require at least 128B alignment");
  236. // TODO: do we have to worry that smem_dk and smem_dv in the epilogue don't line up w smem_k and smem_v due to alignment?
  // Zero-sized placeholders for buffers that are not needed: no dQ accumulator smem unless dQacc_use_TMA, and
  // no P smem when Mma_dKV_is_RS (presumably P is fed to the rs dKV MMA straight from registers).
  237. using SmemdQacc_t = std::conditional_t<!dQacc_use_TMA, cute::array<ElementAccum, 0>, cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>>>;
  238. using SmemP_t = std::conditional_t<Mma_dKV_is_RS, cute::array<Element, 0>, cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>, SmemAlignmentP>>;
  // Shared-memory tensors used by the mainloop. Member order fixes the smem layout. Each buffer's alignment
  // comes from its array_aligned parameter; the struct itself requests
  // max(SmemAlignmentP, SmemAlignmentdS, SmemAlignmentQKVdO) via aligned_struct.
  239. struct TensorStorage : cute::aligned_struct<cute::max(SmemAlignmentP, SmemAlignmentdS, SmemAlignmentQKVdO)> {
  // K and V tiles: single buffer each (SmemLayoutK / SmemLayoutV).
  240. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>, SmemAlignmentQKVdO> smem_k;
  241. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>, SmemAlignmentV> smem_v;
  // dQ accumulator staging; zero-sized unless dQacc_use_TMA.
  242. SmemdQacc_t smem_dqacc;
  // Q (kStages buffers) and dO (kStages_dO buffers).
  243. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>, SmemAlignmentQKVdO> smem_q;
  244. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>, SmemAlignmentQKVdO> smem_do;
  // LSE and dPsum, both laid out as SmemLayoutLSE (kStages stages).
  245. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;
  246. cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;
  // P (zero-sized when Mma_dKV_is_RS) and dS, kStages_dS buffers each.
  247. SmemP_t smem_p;
  248. cute::array_aligned<Element, cute::cosize_v<SmemLayoutPdS>, SmemAlignmentdS> smem_ds;
  249. };
  250. // Host side kernel arguments
  // NOTE(review): every field is const, so this is presumably brace-initialized on the host; keep the member
  // order stable.
  251. struct Arguments {
  // Q and K: pointer, (seqlen, d, head, batch) shape, and strides.
  252. Element const* const ptr_Q;
  253. ShapeQKV const shape_Q;
  254. StrideQKV const stride_Q;
  255. Element const* const ptr_K;
  256. ShapeQKV const shape_K;
  257. StrideQKV const stride_K;
  // V reuses shape_K and dO reuses shape_Q (see to_underlying_arguments); only their strides are given.
  258. Element const* const ptr_V;
  259. StrideQKV const stride_V;
  260. Element const* const ptr_dO;
  261. StrideQKV const stride_dO;
  // dQ accumulator buffer, shape (seqlen_q * d, head, batch).
  262. ElementAccum* const ptr_dQaccum;
  263. ShapedQaccum const shape_dQaccum;
  264. StridedQaccum const stride_dQaccum;
  // LSE and dPsum share shape_LSE, (seqlen, head, batch). The _log2 suffix suggests the LSE is pre-scaled for
  // exp2 -- TODO confirm against whatever produces this buffer.
  265. float const* const ptr_LSE_log2;
  266. ShapeLSE const shape_LSE;
  267. StrideLSE const stride_LSE_log2;
  268. float const* const ptr_dPsum;
  269. StrideLSE const stride_dPsum;
  270. float const softmax_scale;
  // Bound the Q-block range in get_m_block_min_max (read only for causal / local).
  271. int const window_size_left, window_size_right, sink_token_length;
  // Only read when Has_softcap.
  272. float const softcap_val;
  273. int const num_batch;
  // Must be non-null when Deterministic (asserted in to_underlying_arguments).
  274. int* const dq_semaphore;
  // Optional variable-length metadata (default nullptr). load() treats Q / K as varlen only when Varlen is set
  // and the matching cu_seqlens pointer is non-null.
  275. int const* const cu_seqlens_q = nullptr;
  276. int const* const cu_seqlens_k = nullptr;
  277. int const* const seqused_q = nullptr;
  278. int const* const seqused_k = nullptr;
  279. };
  280. // Device side kernel params
  // Built by to_underlying_arguments() with positional brace initialization, so member order matters.
  281. struct Params {
  282. ShapeQKV const shape_Q;
  283. ShapeQKV const shape_K;
  284. ElementAccum* const ptr_dQaccum;
  285. ShapedQaccum const shape_dQaccum;
  286. StridedQaccum stride_dQaccum;
  // divide(bidh) yields the K / V head index used for query head bidh (see load()).
  287. cutlass::FastDivmod qhead_per_khead_divmod;
  // TMA descriptors built on the host from the Arguments pointers / shapes / strides.
  288. TMA_QdO tma_load_Q, tma_load_dO;
  289. TMA_K tma_load_K;
  290. TMA_V tma_load_V;
  291. float const* const ptr_LSE_log2;
  292. ShapeLSE const shape_LSE;
  293. StrideLSE const stride_LSE_log2;
  294. float const* const ptr_dPsum;
  295. StrideLSE const stride_dPsum;
  // softmax_scale is the caller's value. softmax_scale_log2 is softmax_scale * log2(e) without softcap, and
  // softcap_val * log2(e) with softcap (see to_underlying_arguments).
  296. float const softmax_scale, softmax_scale_log2;
  297. int const window_size_left, window_size_right, sink_token_length;
  // 0 without softcap; otherwise the pre-folded softmax_scale / softcap_val, NOT the caller's softcap_val.
  298. float const softcap_val;
  299. int const num_batch;
  300. int* const dq_semaphore;
  301. int const* const cu_seqlens_q = nullptr;
  302. int const* const cu_seqlens_k = nullptr;
  303. int const* const seqused_q = nullptr;
  304. int const* const seqused_k = nullptr;
  305. };
  306. static Params
  307. to_underlying_arguments(Arguments const& args) {
  308. Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
  309. TMA_QdO tma_load_Q = make_tma_copy_A_sm90(
  310. GmemTiledCopyQdO{},
  311. mQ,
  312. SmemLayoutQ{}(_, _, _0{}),
  313. TileShape_MNK{},
  314. ClusterShape{}); // mcast along N mode for this M load, if any
  315. Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO);
  316. TMA_QdO tma_load_dO = make_tma_copy_A_sm90(
  317. GmemTiledCopyQdO{},
  318. mdO,
  319. SmemLayoutdO{}(_, _, _0{}),
  320. TileShape_MNK{},
  321. ClusterShape{}); // mcast along N mode for this M load, if any
  322. Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
  323. TMA_K tma_load_K = make_tma_copy_B_sm90(
  324. GmemTiledCopyKV{},
  325. mK,
  326. SmemLayoutK{},
  327. TileShape_MNK{},
  328. ClusterShape{}); // no mcast for KV
  329. Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V);
  330. TMA_V tma_load_V = make_tma_copy_B_sm90(
  331. GmemTiledCopyKV{},
  332. mV,
  333. SmemLayoutV{},
  334. TileShape_MNK{},
  335. ClusterShape{}); // no mcast for KV
  336. if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }
  337. // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
  338. // Right after this, we multiply by log2(e) before applying exp2.
  339. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
  340. // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
  341. // (assigning it to params.softmax_scale_log2).
  342. // In the backward, we need to multiply by
  343. // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale.
  344. // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale
  345. // (the original softmax_scale) at the end.
  346. return {args.shape_Q, args.shape_K,
  347. args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum,
  348. cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
  349. tma_load_Q, tma_load_dO, tma_load_K, tma_load_V,
  350. args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
  351. args.softmax_scale,
  352. !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
  353. args.window_size_left, args.window_size_right, args.sink_token_length,
  354. !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
  355. args.num_batch, args.dq_semaphore,
  356. args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k};
  357. }
  358. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  359. CUTLASS_DEVICE
  360. static void prefetch_tma_descriptors(Params const& params) {
  361. cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
  362. cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor());
  363. cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
  364. cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
  365. }
  366. CUTLASS_DEVICE
  367. cute::tuple<int, int> get_m_block_min_max(Params const& params, SeqlenInfo_t const& seqlen_info,
  368. int n_block, int bidb) {
  369. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  370. int const seqlen_q = seqlen_info.seqlen_q;
  371. int const seqlen_k = seqlen_info.seqlen_k;
  372. int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
  373. if constexpr (Is_local) {
  374. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  375. if (n_block >= cute::ceil_div(params.sink_token_length, kBlockN)) {
  376. m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM));
  377. }
  378. }
  379. int m_block_min = 0;
  380. if constexpr (Is_causal || Is_local) {
  381. m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM);
  382. }
  383. return {m_block_min, m_block_max};
  384. }
  385. template <typename SchedulerPrefetch, typename SharedStorage>
  386. CUTLASS_DEVICE void
  387. load(Params const& params,
  388. MainloopPipeline pipeline_q,
  389. MainloopPipeline_dO pipeline_do,
  390. PipelineState& smem_pipe_write,
  391. PipelineState_dO& smem_pipe_write_do,
  392. SharedStorage &shared_storage,
  393. SchedulerPrefetch const& scheduler_prefetch,
  394. cute::tuple<int32_t, int32_t, int32_t> block_coord
  395. ) {
  396. auto [n_block, bidh, bidb] = block_coord;
  397. SeqlenInfo_t seqlen_info{
  398. bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
  399. params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
  400. };
  401. auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb);
  402. // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access.
  403. if constexpr (Is_causal || Is_local || Varlen) {
  404. if (m_block_max <= m_block_min) {
  405. scheduler_prefetch();
  406. return;
  407. }
  408. }
  409. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  410. Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{});
  411. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  412. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  413. Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{});
  414. Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{});
  415. int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
  416. // Prepare the TMA loads
  417. uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  418. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  419. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  420. bool const is_varlen_q = Varlen && params.cu_seqlens_q;
  421. bool const is_varlen_k = Varlen && params.cu_seqlens_k;
  422. Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
  423. Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
  424. Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
  425. Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
  426. Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0);
  427. Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0);
  428. Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
  429. Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
  430. Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
  431. Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
  432. Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _)
  433. Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _)
  434. Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{}));
  435. Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{}));
  436. Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{}));
  437. Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{}));
  438. // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout<ClusterShape>{},
  439. // group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE)
  440. // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout<ClusterShape>{},
  441. // group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE)
  442. auto block_tma_Q = params.tma_load_Q.get_slice(cluster_local_block_id.y);
  443. auto block_tma_dO = params.tma_load_dO.get_slice(cluster_local_block_id.y);
  444. Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ));
  445. Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ));
  446. Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO));
  447. Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO));
  448. auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{},
  449. group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA)
  450. auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{},
  451. group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA)
  452. auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
  453. uint16_t mcast_mask_qdo = 0;
  454. if constexpr (cute::is_same_v<GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
  455. auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
  456. for (int n = 0; n < size<1>(block_layout); ++n) {
  457. mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{}));
  458. }
  459. }
  460. int m_block = m_block_min;
  461. int lane_predicate = cute::elect_one_sync();
  462. if (lane_predicate) {
  463. pipeline_q.producer_acquire(smem_pipe_write);
  464. copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),
  465. tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index()));
  466. copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)),
  467. gLSE(_, m_block), sLSE(_, smem_pipe_write.index()));
  468. }
  469. // // Wait for the MMA warpgroups to say that smem_k and smem_v are ready
  470. // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
  471. if (lane_predicate) {
  472. // Copy K tile and V tile from GMEM to SMEM.
  473. shared_storage.pipelines.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV);
  474. copy(params.tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK);
  475. copy(params.tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV);
  476. #pragma unroll (kHeadDim < 256 ? 2 : 1)
  477. for (; m_block < m_block_max - 1; ++m_block) {
  478. // If Q and dO have the same number of stages, we can use the same pipeline state variable
  479. // to reduce registers
  480. PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_write, smem_pipe_write_do);
  481. pipeline_do.producer_acquire(smem_pipe_write_do_cur);
  482. copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),
  483. tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index()));
  484. copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)),
  485. gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index()));
  486. if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; }
  487. ++smem_pipe_write;
  488. pipeline_q.producer_acquire(smem_pipe_write);
  489. copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),
  490. tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index()));
  491. copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)),
  492. gLSE(_, m_block + 1), sLSE(_, smem_pipe_write.index()));
  493. }
  494. }
  495. scheduler_prefetch();
  496. if (lane_predicate) {
  497. PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_write, smem_pipe_write_do);
  498. pipeline_do.producer_acquire(smem_pipe_write_do_cur);
  499. copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST),
  500. tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index()));
  501. copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)),
  502. gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index()));
  503. if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; }
  504. ++smem_pipe_write;
  505. }
  506. if constexpr (Q_dO_same_stages) { smem_pipe_write_do = smem_pipe_write; }
  507. }
  508. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  509. CUTLASS_DEVICE void
  510. load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do,
  511. PipelineState& smem_pipe_write) {
  512. static_assert(Q_dO_same_stages, "Q and dO must have the same number of stages");
  513. // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write
  514. PipelineState smem_pipe_write_do = smem_pipe_write;
  515. // Issue the epilogue waits
  516. if (cute::elect_one_sync()) {
  517. /* This helps avoid early exit of blocks in Cluster
  518. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  519. * then would just be acquired since the phase was still inverted from make_producer_start_state
  520. */
  521. pipeline_q.producer_tail(smem_pipe_write);
  522. pipeline_do.producer_tail(smem_pipe_write_do);
  523. }
  524. }
  525. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  526. CUTLASS_DEVICE void
  527. load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do,
  528. PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do) {
  529. // Issue the epilogue waits
  530. if (cute::elect_one_sync()) {
  531. /* This helps avoid early exit of blocks in Cluster
  532. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  533. * then would just be acquired since the phase was still inverted from make_producer_start_state
  534. */
  535. pipeline_q.producer_tail(smem_pipe_write);
  536. pipeline_do.producer_tail(smem_pipe_write_do);
  537. }
  538. }
  539. template <typename SharedStorage>
  540. CUTLASS_DEVICE void
  541. store_dq(Params const& params,
  542. SharedStorage &shared_storage,
  543. cute::tuple<int32_t, int32_t, int32_t> block_coord
  544. ) {
  545. if constexpr (!dQacc_use_TMA) { return; }
  546. auto [n_block, bidh, bidb] = block_coord;
  547. SeqlenInfo_t seqlen_info{
  548. bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
  549. params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
  550. };
  551. auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb);
  552. // It's possible to have m_block_max <= m_block_min. Exit early
  553. if constexpr (Is_causal || Is_local || Varlen) {
  554. if (m_block_max <= m_block_min) { return; }
  555. }
  556. Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});
  557. static constexpr int dQ_TMA_num_bytes = CUTE_STATIC_V(size<0>(sdQ)) * sizeof(ElementAccum);
  558. bool const is_varlen = Varlen && params.cu_seqlens_q;
  559. Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),
  560. params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
  561. Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_)); // (M * K, _)
  562. Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int<kBlockM * kHeadDim / NumMmaWarpGroups>{}); // (M * K / WG, WG, _)
  563. int const num_batch = params.num_batch;
  564. int const num_head = get<2>(params.shape_Q);
  565. int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh;
  566. using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
  567. bool const lane_predicate = cute::elect_one_sync();
  568. int m_block = m_block_min;
  569. #pragma unroll 2
  570. for (; m_block < m_block_max; ++m_block) {
  571. if constexpr (Deterministic) {
  572. Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block);
  573. }
  574. #pragma unroll
  575. for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) {
  576. cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/); // sdQ full, to be written to gmem
  577. if (lane_predicate) {
  578. SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
  579. tma_store_arrive();
  580. }
  581. }
  582. // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int<x>.
  583. for_each(make_int_sequence<NumMmaWarpGroups>{}, [&] (auto warpgroup_idx) {
  584. if (lane_predicate) { tma_store_wait<NumMmaWarpGroups - 1 - CUTE_STATIC_V(warpgroup_idx)>(); }
  585. cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/); // sdQ empty, ready to be written to
  586. });
  587. if constexpr (Deterministic) {
  588. Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
  589. }
  590. }
  591. if constexpr (Is_local && Deterministic) {
  592. constexpr int kBlockM = get<0>(TileShape_MNK{});
  593. int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM);
  594. #pragma unroll 2
  595. for (; m_block < m_block_global_max; ++m_block) {
  596. Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
  597. }
  598. }
  599. }
  600. CUTLASS_DEVICE void
  601. mma_init() {
  602. // We're not currently using this bc we're not using persistent scheduler
  603. // // Tell producer (warp 0) that smem_k and smem_v are ready
  604. // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
  605. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  606. if constexpr (dQacc_use_TMA) {
  607. if (warp_idx_in_warpgroup == 0) {
  608. cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQEmptyWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/); // sdQ empty, ready to be written to
  609. }
  610. }
  611. }
  // Consumer mainloop run by the MMA warpgroups for one K/V tile, block_coord = (n_block, bidh, bidb).
  // For each m_block in [m_block_min, m_block_max), bwd_step:
  //   - computes S = Q K^T and dP = dO V^T (operands swapped when SdP_swapAB),
  //   - forms P = exp2(S * softmax_scale_log2 - LSE) and dS = P * (dP - dPsum) (times dtanh(S) when Has_softcap),
  //   - accumulates dV += P^T dO and dK += dS^T Q in the register fragments tdVrdV / tdKrdK,
  //   - computes dQ = dS K, handing it to the store_dq warp through sdQ + named barriers when
  //     dQacc_use_TMA, or atomically adding it to gmem dQaccum otherwise.
  // Q/LSE stages come from pipeline_q and dO/dPsum stages from pipeline_do; K and V are read from smem
  // after barrier_KV completes with phase work_idx % 2.
  // thread_idx is presumably this thread's index among the NumMmaThreads MMA threads (warp_group_idx is
  // derived as thread_idx / NumThreadsPerWarpGroup) — confirm against the caller.
  // Returns false, without touching the pipelines or clearing tdKrdK / tdVrdV, when the m_block range is
  // empty. Otherwise it scales dK by softmax_scale, increments work_idx, and returns true. smem_pipe_read
  // (and smem_pipe_read_do) advance by one stage per processed m_block.
  612. template <typename SharedStorage, typename FrgTensordKV>
  613. CUTLASS_DEVICE bool
  614. mma(Params const& params,
  615. MainloopPipeline pipeline_q,
  616. MainloopPipeline_dO pipeline_do,
  617. PipelineState& smem_pipe_read,
  618. PipelineState_dO& smem_pipe_read_do,
  619. FrgTensordKV& tdKrdK,
  620. FrgTensordKV& tdVrdV,
  621. int thread_idx,
  622. int &work_idx,
  623. cute::tuple<int32_t, int32_t, int32_t> block_coord,
  624. SharedStorage& shared_storage
  625. ) {
  626. static_assert(is_rmem<FrgTensordKV>::value, "dK and dV tensor must be rmem resident.");
  627. int n_block = get<0>(block_coord);
  628. int bidb = get<2>(block_coord);
  629. SeqlenInfo_t seqlen_info{
  630. bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
  631. params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
  632. };
  633. auto [m_block_min, m_block_max] = get_m_block_min_max(params, seqlen_info, n_block, bidb);
  634. // It's possible to have m_block_max <= m_block_min. Exit early (tdKrdK / tdVrdV are not cleared in that case)
  635. if constexpr (Is_causal || Is_local || Varlen) {
  636. if (m_block_max <= m_block_min) { return false; }
  637. }
  638. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
  639. Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{});
  640. Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
  641. Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{});
  642. Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{});
  643. Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{});
  644. Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{});
  645. Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{});
  646. Tensor sP_pi = cute::as_position_independent_swizzle_tensor(sP);
  647. Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{});
  648. Tensor sPt_pi = cute::as_position_independent_swizzle_tensor(sPt);
  649. Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{});
  650. Tensor sdS_pi = cute::as_position_independent_swizzle_tensor(sdS);
  651. Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{});
  652. Tensor sdSt_pi = cute::as_position_independent_swizzle_tensor(sdSt);
  653. Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});
  654. Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{});
  655. Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{});
  656. static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and
  657. stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and
  658. size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and
  659. size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup,
  660. "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
  661. constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
  662. Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),
  663. make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
  664. Layout warp_group_thread_layout_dq = make_layout(make_shape(Int<NumMmaWarpGroups>{}),
  665. make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
  666. int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
  667. TiledMmaSdP tiled_mma_SdP;
  668. using TiledMmadP = std::conditional_t<!Mma_dP_is_RS, TiledMmaSdP, TiledMmadPRS>;
  669. TiledMmadP tiled_mma_dP;
  670. TiledMmadKV tiled_mma_dKV;
  671. TiledMmadQ tiled_mma_dQ;
  672. auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx));
  673. auto wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx));
  674. auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);
  675. auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx));
  676. auto wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx));
  677. auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP);
  678. auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx);
  679. R2STiledCopydQaccum r2s_tiled_copy_dQaccum;
  680. auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
  681. Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ);
  682. // if (thread_idx == 0) { print(sdQ); printf("\n"); print(tdQsdQaccum); printf("\n"); }
  683. // Allocate "fragments/descriptors"
  684. // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda,
  685. // because some partition_fragment_A/B don't compile.
  686. // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function
  687. Tensor tSrQ = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(wg_mma_SdP, sQ);
  688. Tensor tSrK = mma_partition_fragment_AB</*A=*/SdP_swapAB>(wg_mma_SdP, sK);
  689. Tensor tdPrdO = mma_partition_fragment_AB</*A=*/!SdP_swapAB>(wg_mma_SdP, sdO);
  690. Tensor tdPrV = mma_partition_fragment_AB</*A=*/SdP_swapAB>(wg_mma_dP, sV);
  691. Tensor tdVrdO = mma_partition_fragment_AB</*A=*/dKV_swapAB>(wg_mma_dKV, sdOt);
  692. Tensor tdKrQ = mma_partition_fragment_AB</*A=*/dKV_swapAB>(wg_mma_dKV, sQt);
  693. Tensor tdQrdS = mma_partition_fragment_AB</*A=*/!dQ_swapAB>(wg_mma_dQ, sdS);
  694. Tensor tdQrK = mma_partition_fragment_AB</*A=*/dQ_swapAB>(wg_mma_dQ, sKt);
  695. Tensor tPsP = smem_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sP_pi, sPt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  696. Tensor tdSsdS = smem_thr_copy_PdS.partition_D(cute::conditional_return<!SdP_swapAB>(sdS_pi, sdSt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  697. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_PdS); print(sP_pi); printf("\n"); print(sPt_pi); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); }
  698. // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices
  699. // or row indices, depending on whether SdP_swapAB.
  700. Tensor tLSEsLSE = cute::conditional_return<!SdP_swapAB>(
  701. group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE)
  702. group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE)
  703. Tensor tLSEsdPsum = cute::conditional_return<!SdP_swapAB>(
  704. group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)),
  705. group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _)));
  706. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); }
  707. // If we want to split the stats among the 8 threads that share the same rows.
  708. static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tLSEsLSE))::value, 8);
  709. auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
  710. auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  711. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  712. };
  713. int bidh = get<1>(block_coord);
  714. int const seqlen_q = seqlen_info.seqlen_q;
  715. int const seqlen_k = seqlen_info.seqlen_k;
  716. // For the case where we do atomicAdd directly to gdQaccum instead of using TMA
  717. bool const is_varlen = Varlen && params.cu_seqlens_q;
  718. Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.ptr_dQaccum)),
  719. params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
  720. Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(_)); // (M * K, _)
  721. Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int<kBlockM * kHeadDim / NumMmaWarpGroups>{}); // (M * K / WG, WG, _)
  722. // We can reuse r2s_thr_copy_dQaccum for this partitioning
  723. Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum);
  724. // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); }
  725. flash::Mask<kBlockM, kBlockN, false /*PackGQA*/, TiledMmaSdP, SdP_swapAB> mask(
  726. thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length,
  727. params.qhead_per_khead_divmod
  728. );
  729. int m_block = m_block_min;
  730. clear(tdKrdK);
  731. clear(tdVrdV);
  732. // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero;
  // Wait until the K and V tiles of this work item (TMA loads on barrier_KV) have landed in smem;
  // the barrier phase alternates with work_idx, which is incremented once per processed tile.
  733. cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.pipelines.barrier_KV.try_wait(work_idx % 2));
  734. if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_KV.wait(work_idx % 2); }
  // With an RS dP gemm, V is taken from registers: copy it from smem once and reuse it for every m_block.
  735. if constexpr (Mma_dP_is_RS) {
  736. using SmemCopyAtomV = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
  737. auto smem_tiled_copy_V = make_tiled_copy_A(SmemCopyAtomV{}, tiled_mma_dP);
  738. auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx);
  739. Tensor tdPrV_copy_view = smem_thr_copy_V.retile_D(tdPrV);
  740. Tensor tdPsV_copy_view = smem_thr_copy_V.partition_S(cute::as_position_independent_swizzle_tensor(sV));
  741. cute::copy(smem_tiled_copy_V, tdPsV_copy_view, tdPrV_copy_view);
  742. }
  // One backward step for m_block; mask_fn(tSrS, m_block) applies the masking variant chosen by the calling loop.
  743. auto bwd_step = [&](int m_block, auto mask_fn) {
  744. Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));
  745. consumer_wait(pipeline_q, smem_pipe_read);
  746. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/SdP_swapAB>(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS);
  747. Tensor tLSErLSE = cute::conditional_return<!ShuffleLSE>(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));
  748. if constexpr (!ShuffleLSE) {
  749. cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE);
  750. } else {
  751. #pragma unroll
  752. for (int i = 0; i < kStatsPerThread; ++i) {
  753. // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values
  754. tLSErLSE(i) = tLSEsLSE((thread_idx % 32) / 4 + i * 8, smem_pipe_read.index());
  755. }
  756. }
  757. Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<!SdP_swapAB ? 0 : 1, !SdP_swapAB ? 1 : 0>(TileShape_MNK{}));
  758. PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return<Q_dO_same_stages>(smem_pipe_read, smem_pipe_read_do);
  759. consumer_wait(pipeline_do, smem_pipe_read_do_cur);
  760. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/SdP_swapAB>(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP);
  // Allow at most one gemm in flight: the S gemm (issued first) is complete, the dP gemm may still be running.
  761. warpgroup_wait<1>();
  762. if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }
  763. // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  764. Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol</*Transposed=*/SdP_swapAB>(tSrS.layout()));
  765. // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh
  766. auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }();
  767. mask_fn(tSrS, m_block);
  768. #pragma unroll
  769. for (int mi = 0; mi < size<0>(scores); ++mi) {
  770. float const lse_scaled = [&] {
  771. if constexpr (!ShuffleLSE) return tLSErLSE(mi);
  772. else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
  773. }();
  774. #pragma unroll
  775. for (int ni = 0; ni < size<1>(scores); ++ni) {
  776. scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled);
  777. }
  778. }
  779. Tensor tLSErdPsum = cute::conditional_return<!ShuffledPsum>(make_fragment_like(tLSEsdPsum(_, _0{})), make_tensor<ElementAccum>(Int<kStatsPerThread>{}));
  780. if constexpr (!ShuffledPsum) {
  781. cute::copy(tLSEsdPsum(_, smem_pipe_read_do_cur.index()), tLSErdPsum);
  782. } else {
  783. #pragma unroll
  784. for (int i = 0; i < kStatsPerThread; ++i) {
  785. tLSErdPsum(i) = tLSEsdPsum((thread_idx % 32) / 4 + i * 8, smem_pipe_read_do_cur.index());
  786. }
  787. }
  // The dP gemm must be complete before dS is formed from tdPrdP.
  788. warpgroup_wait<0>();
  789. // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  790. Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
  791. #pragma unroll
  792. for (int mi = 0; mi < size<0>(dS); ++mi) {
  793. float const dP_sum_cur = [&] {
  794. if constexpr (!ShuffledPsum) return tLSErdPsum(mi);
  795. else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
  796. }();
  797. #pragma unroll
  798. for (int ni = 0; ni < size<1>(dS); ++ni) {
  799. dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur);
  800. if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); }
  801. }
  802. }
  803. // Convert scores from fp32 to fp16/bf16
  804. Tensor rP = make_tensor_like<Element>(tSrS);
  805. flash::convert_type_out(tSrS, rP);
  806. if constexpr (!Mma_dKV_is_RS) {
  807. // Need to sync to make sure P has already been used in the previous iteration before writing new values
  808. if constexpr (kStages_dS == 1) {
  809. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(BwdNamedBarriers::PdS) /*id*/);
  810. }
  811. Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N)
  812. cute::copy(smem_tiled_copy_PdS, tPaP, tPsP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index())));
  813. }
  814. Tensor rdS = make_tensor_like<Element>(tdPrdP);
  815. flash::convert_type_out(tdPrdP, rdS);
  816. // If there's double buffering on dS, we don't need to sync here.
  817. // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
  818. // But because both WGs have to sync at the end of the loop and double buffering,
  819. // this race condition is not possible.
  820. // This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and
  821. // (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS.
  822. if constexpr (!Mma_dKV_is_RS || (kStages_dS == 1 && Mma_dKV_is_RS)) {
  823. cutlass::arch::fence_view_async_shared();
  824. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(BwdNamedBarriers::PdS) /*id*/);
  825. }
  826. // For hdim 64, It's faster to write to smem_dS first before the dV gemm
  827. Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
  828. cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index())));
  829. if constexpr (!Slice_dQKV_Mma) {
  830. // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure
  831. if constexpr (Mma_dKV_is_RS) {
  832. Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
  833. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);
  834. } else {
  835. Tensor tdVrP = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sPt);
  836. Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  837. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);
  838. }
  839. // SMEM fence to make sure sdS is written before it's read by WGMMA
  840. cutlass::arch::fence_view_async_shared();
  841. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(BwdNamedBarriers::PdS) /*id*/);
  842. Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
  843. Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  844. flash::gemm</*zero_init=*/true, /*wg_wait=*/1, /*SwapAB=*/dQ_swapAB>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);
  845. pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO (the dV gemm reading this dO stage is done after wg_wait=1)
  846. if constexpr (Mma_dKV_is_RS) {
  847. Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
  848. flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  849. } else {
  850. Tensor tdKrdS = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sdSt);
  851. Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  852. flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  853. }
  854. if constexpr (dQacc_use_TMA) {
  855. int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1;
  856. cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQEmptyWG1) + warp_group_idx /*id*/); // wait until sdQ is empty (previous dQ consumed by the store_dq warp, or the initial arrive from mma_init)
  857. Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ);
  858. cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
  859. cutlass::arch::fence_view_async_shared();
  860. cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::dQFullWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem
  861. } else {
  862. // We can reuse r2s_thr_copy_dQaccum for this partitioning
  863. Tensor tdQrdQ_atomic = recast<float4>(r2s_thr_copy_dQaccum.retile_S(tdQrdQ));
  864. Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));
  865. static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic)));
  866. #pragma unroll
  867. for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
  868. }
  869. } else { // Slice_dQKV_Mma
  870. static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS));
  871. Tensor tdVrP = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sPt);
  872. Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  873. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/0>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);
  874. cutlass::arch::fence_view_async_shared();
  875. cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(BwdNamedBarriers::PdS) /*id*/);
  876. Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
  877. Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  878. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1, /*SwapAB=*/dQ_swapAB, /*M_slice=*/0>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);
  879. flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/1>(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV);
  880. Tensor tdQrdQ_atomic = recast<float4>(r2s_thr_copy_dQaccum.retile_S(tdQrdQ));
  881. Tensor tdQgdQaccum_atomic = recast<float4>(tdQgdQaccum(_, _, _, m_block));
  882. #pragma unroll
  883. for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
  884. Tensor tdKrdS = mma_partition_fragment_AB</*A=*/!dKV_swapAB>(wg_mma_dKV, sdSt);
  885. Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return<kStages_dS==1>(_0{}, smem_pipe_read.index()));
  886. flash::gemm</*zero_init=*/false, /*wg_wait=*/1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/0>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  887. pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO
  888. flash::gemm</*zero_init=*/true, /*wg_wait=*/0, /*SwapAB=*/dQ_swapAB, /*M_slice=*/1>(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ);
  889. #pragma unroll
  890. for (int i = size(tdQrdQ_atomic) / 2; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
  891. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1, /*SwapAB=*/dKV_swapAB, /*M_slice=*/1>(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  892. }
  // Drain every outstanding gemm of this step (they read the current Q stage), then hand the Q/LSE stage back.
  893. warpgroup_wait<0>();
  894. pipeline_q.consumer_release(smem_pipe_read); // release Q
  895. ++smem_pipe_read;
  896. if constexpr (!Q_dO_same_stages) { ++smem_pipe_read_do; }
  897. };
  898. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
  899. // this helps quite a bit to not have to do causal masking for most of the iterations.
  // Loop 1 (SeparateMaskingIterations only): m_blocks below m_block_masking_max apply the causal/local masks.
  900. if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) {
  901. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
  902. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  903. int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1;
  904. CUTLASS_PRAGMA_NO_UNROLL
  905. for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) {
  906. bwd_step(m_block, mask_fn);
  907. }
  908. }
  // Loop 2: only the seqlen_k mask (plus the causal/local masks when !SeparateMaskingIterations),
  // up to the first m_block that may need the local left-window mask.
  909. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  910. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  911. int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations
  912. ? m_block_max
  913. : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM);
  914. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal && !SeparateMaskingIterations, Is_local && !SeparateMaskingIterations>(tSrS, m_block, n_block); };
  915. CUTLASS_PRAGMA_NO_UNROLL
  916. for (; m_block < m_block_max_before_local_mask; ++m_block) {
  917. bwd_step(m_block, mask_fn);
  918. }
  // Loop 3 (local + SeparateMaskingIterations): remaining m_blocks with the local mask but no causal mask.
  919. if constexpr (Is_local && SeparateMaskingIterations) {
  920. auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, false /*Causal_mask*/, Is_local>(tSrS, m_block, n_block); };
  921. CUTLASS_PRAGMA_NO_UNROLL
  922. for (; m_block < m_block_max; ++m_block) {
  923. bwd_step(m_block, mask_fn);
  924. }
  925. }
  926. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }
  // dK was accumulated from dS without softmax_scale; apply the scale once here.
  927. #pragma unroll
  928. for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }
  929. if constexpr (Q_dO_same_stages) { smem_pipe_read_do = smem_pipe_read; }
  930. ++work_idx;
  931. return true;
  932. }
  933. };
  934. } // namespace flash