mainloop_bwd_sm90_tma_gmma_ws.hpp 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include <cutlass/barrier.h>
  10. #include "cutlass/pipeline/pipeline.hpp"
  11. #include "cute/tensor.hpp"
  12. #include "cutlass/gemm/collective/collective_builder.hpp"
  13. #include "named_barrier.hpp"
  14. #include "softmax.h"
  15. #include "utils.h"
  16. namespace flash {
  17. using namespace cute;
  18. template <int Stages, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
  19. bool Is_causal_, bool Varlen_, bool Deterministic,
  20. bool dKV_swapAB_, bool dQ_swapAB_,
  21. int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
  22. struct CollectiveMainloopBwd {
  // ---- Template parameters re-exported as member constants / types ----
  static constexpr int kStages = Stages;
  using ClusterShape = ClusterShape_;
  using TileShape_MNK = TileShape_MNK_;
  using Element = Element_;
  using ElementAccum = ElementAccum_;
  using ArchTag = ArchTag_;
  static constexpr bool Is_causal = Is_causal_;
  static constexpr bool Varlen = Varlen_;
  // The S/dP GEMMs always run with A and B swapped (hard-coded here, not a template parameter).
  static constexpr bool SdP_swapAB = true;
  static constexpr bool dKV_swapAB = dKV_swapAB_;
  static constexpr bool dQ_swapAB = dQ_swapAB_;
  static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
  // Tile extents: kBlockM rows of Q/dO, kBlockN rows of K/V, kHeadDim columns (see SmemLayoutQ / SmemLayoutK).
  static constexpr int kBlockM = get<0>(TileShape_MNK{});
  static constexpr int kBlockN = get<1>(TileShape_MNK{});
  static constexpr int kHeadDim = get<2>(TileShape_MNK{});
  // The dQ GEMM is spread over NumdQWarpGroups warpgroups, i.e. kNThreadsdQ threads in total.
  static constexpr int NumdQWarpGroups = 2;
  static constexpr int kNThreadsdQ = NumdQWarpGroups * cutlass::NumThreadsPerWarpGroup;
  // Requires compute capability 9.0 (SM90) or newer.
  static_assert(ArchTag::kMinComputeCapability >= 90);
  // Clusters may only extend along the N mode; Q/dO are then multicast across that mode.
  static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1);
  // Whether the dQ GEMM takes its A operand from registers. Because SdP_swapAB is hard-coded
  // to true above, this currently always evaluates to false.
  static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  // ---- Tiled MMAs ----
  // S/dP GEMM atom tile (contraction over kHeadDim). With SdP_swapAB the atom produces a
  // (kBlockN x kBlockM / AtomLayoutMSdP) tile, i.e. the transpose of the (M x N) S/dP tile.
  // The "2 / AtomLayoutMSdP" factors assume AtomLayoutMSdP divides 2 (two atoms in total).
  using TileShapeAtomSdP = std::conditional_t<
    !SdP_swapAB,
    Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
    Shape<Int<kBlockN>, Int<kBlockM / AtomLayoutMSdP>, Int<kHeadDim>>
  >;
  using AtomLayoutSdP = std::conditional_t<
    !SdP_swapAB,
    Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
    Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  >;
  // Both S/dP operands are read from shared memory (ss_op_selector).
  using TiledMmaSdP = decltype(cute::make_tiled_mma(
    cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
    AtomLayoutSdP{}));
  // dK/dV GEMM atom tile (contraction over kBlockM).
  using TileShapeAtomdKV = std::conditional_t<
    !dKV_swapAB,
    Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
    Shape<Int<kHeadDim>, Int<kBlockN / AtomLayoutNdKV>, Int<kBlockM>>
  >;
  using AtomLayoutdKV = std::conditional_t<
    !dKV_swapAB,
    Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
    Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  >;
  // Since SdP_swapAB is true, the rs_op_selector branch is taken: A operand from registers
  // (K-major), B operand from shared memory (MN-major).
  using TiledMmadKV = decltype(cute::make_tiled_mma(
    std::conditional_t<
      !SdP_swapAB,
      decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
      decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
    >{},
    AtomLayoutdKV{}));
  // dQ GEMM atom tile (contraction over kBlockN), split over NumdQWarpGroups atoms.
  using TileShapeAtomdQ = std::conditional_t<
    !dQ_swapAB,
    Shape<Int<kBlockM>, Int<kHeadDim / (NumdQWarpGroups / AtomLayoutMdQ)>, Int<kBlockN>>,
    Shape<Int<kHeadDim>, Int<kBlockM / AtomLayoutMdQ>, Int<kBlockN>>
  >;
  using AtomLayoutdQ = std::conditional_t<
    !dQ_swapAB,
    Layout<Shape<Int<AtomLayoutMdQ>, Int<NumdQWarpGroups / AtomLayoutMdQ>, _1>>,
    Layout<Shape<Int<NumdQWarpGroups / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  >;
  // Operand majors of the dQ GEMM. TiledMmadQ below spells the same majors out explicitly
  // rather than referring to these constants.
  static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
  static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
  using TiledMmadQ = decltype(cute::make_tiled_mma(
    std::conditional_t<
      !dQ_swapAB,
      std::conditional_t<
        Mma_dQ_is_RS,
        decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
        decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
      >,
      decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
    >{},
    AtomLayoutdQ{}));
  // ---- Shared-memory layouts ----
  // Q and dO: (kBlockM, kHeadDim, kStages) multi-stage buffers built from a K-major swizzle atom.
  using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
    Int<kBlockM>, Int<dKV_swapAB ? kHeadDim : kHeadDim / (2 / AtomLayoutNdKV)>>());
  using SmemLayoutQ =
    decltype(tile_to_shape(SmemLayoutAtomQ{},
             make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  using SmemLayoutdO = SmemLayoutQ;
  // K and V: a single (kBlockN, kHeadDim) tile each (no stage mode).
  using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
    Int<kBlockN>, Int<dQ_swapAB ? kHeadDim : kHeadDim / (NumdQWarpGroups / AtomLayoutMdQ)>>());
  using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
  using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
    decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
  // P: a single (kBlockM, kBlockN) tile; dS: kStages of them.
  using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
    decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
  using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
    decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, make_shape(Int<kBlockM>{}, Int<kBlockN>{}, Int<kStages>{})));
  // LSE / dPsum: kBlockM values per stage. SmemLayoutLSEMma is the same storage viewed as
  // (kBlockN, kBlockM, kStages) with a stride-0 N mode, i.e. broadcast along N.
  // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80
  using SmemLayoutLSE = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>, cute::Stride<_1, Int<cute::round_up(kBlockM, 32)>>>;
  using SmemLayoutLSEMma = cute::Layout<cute::Shape<Int<kBlockN>, Int<kBlockM>, Int<kStages>>, cute::Stride<_0, _1, Int<cute::round_up(kBlockM, 32)>>>;
  // Note this is the transpose in terms of the view, not in terms of memory.
  using SmemLayoutQt =
    decltype(cute::composition(SmemLayoutQ{},
             make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                         make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  using SmemLayoutdOt =
    decltype(cute::composition(SmemLayoutdO{},
             make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                         make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  using SmemLayoutKt =
    decltype(cute::composition(SmemLayoutK{},
             make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                         make_stride(Int<kBlockN>{}, _1{}))));
  using SmemLayoutPt =
    decltype(cute::composition(SmemLayoutP{},
             make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                         make_stride(Int<kBlockM>{}, _1{}))));
  using SmemLayoutdSt =
    decltype(cute::composition(SmemLayoutdS{},
             make_layout(make_shape(Int<kBlockN>{}, Int<kBlockM>{}, Int<kStages>{}),
                         make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kBlockN>{}))));
  // Register -> smem copy of the dQ accumulator: kNThreadsdQ threads laid out in one row.
  // Thread layout, 256 threads per row
  using R2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreadsdQ>>, Stride<_1>>;
  using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, R2SLayoutAtomdQaccum{},
                                                       Layout<Shape < _4>>{})); // Val layout, 4 vals per store
  // Flat (kBlockM * kHeadDim) view of the dQ accumulator buffer.
  using SmemLayoutdQaccum = Layout<Shape<Int<kBlockM * kHeadDim>>, Stride<_1>>;
  // We want dQaccum smem to have last dimension 32, so that we only need to do 1 TMA instruction.
  // The layout Layout_K_SW128_Atom<ElementAccum> has 32 elements per row.
  // // TMA limit is that each dimension in smem must be <= 256.
  // static constexpr int ElemsPerRowTMA = (kBlockM * kHeadDim) / 32 <= 256 ? 32 : 64;
  static constexpr int ElemsPerRowTMA = 32; // If we change this, we'll also need to change the dQ shape in host.
  static_assert((kBlockM * kHeadDim) % ElemsPerRowTMA == 0);
  using TileShape_dQaccum = cute::Shape<Int<(kBlockM * kHeadDim) / ElemsPerRowTMA>, Int<ElemsPerRowTMA>>;
  // using TileShape_dQaccum = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
  using SmemLayoutdQaccumTMA =
    decltype(tile_to_shape(GMMA::Layout_K_SW128_Atom<ElementAccum>{}, TileShape_dQaccum{}));
  // Same layout with the swizzle functor stripped off.
  using SmemLayoutdQaccumTMANoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutdQaccumTMA{}));
  // ---- Register -> shared-memory copy atoms (stmatrix) ----
  // The transposing (_T) variant is used when the corresponding GEMM runs with A/B swapped.
  using SmemCopyAtomPdS = Copy_Atom<
    std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
    Element>;
  using SmemCopyAtomdKV = Copy_Atom<
    std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
    Element>;
  // ---- Global-memory TMA ----
  // Q/dO atom is chosen from the cluster's N extent; load() checks whether it is SM90_TMA_LOAD_MULTICAST.
  using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{})));
  using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  // dQ is accumulated in global memory with a TMA reduce-add.
  using GmemTiledCopydQaccum = cute::SM90_TMA_REDUCE_ADD;
  using GmemTiledCopyLSE = cute::SM90_TMA_LOAD;
  using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
  using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
  using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
  using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
  // The decltype aliases below must name exactly the types that to_underlying_arguments()
  // builds at runtime (it assigns make_tma_copy results into variables of these types).
  using TMA_QdO = decltype(make_tma_copy(
    GmemTiledCopyQdO{},
    make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
    take<0, 2>(SmemLayoutQ{}),
    select<0, 2>(TileShape_MNK{}),
    size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
  using TMA_K = decltype(make_tma_copy(
    GmemTiledCopyKV{},
    make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
    SmemLayoutK{},
    select<1, 2>(TileShape_MNK{}),
    _1{})); // no mcast for KV
  using TMA_V = decltype(make_tma_copy(
    GmemTiledCopyKV{},
    make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
    SmemLayoutV{},
    select<1, 2>(TileShape_MNK{}),
    _1{})); // no mcast for KV
  using TMA_add_dQ = decltype(make_tma_copy(
    GmemTiledCopydQaccum{},
    make_tensor(make_gmem_ptr(static_cast<ElementAccum*>(nullptr)), ShapeQKV{}, StrideQKV{}),
    SmemLayoutdQaccumTMA{},
    TileShape_dQaccum{},
    _1{})); // no mcast for dQ
  // NOTE(review): built for ElementAccum const*, while Arguments/Params carry LSE and dPsum as
  // float const*; this assumes ElementAccum is float — confirm.
  using TMA_LSE = decltype(make_tma_copy(
    GmemTiledCopyLSE{},
    make_tensor(make_gmem_ptr(static_cast<ElementAccum const*>(nullptr)), ShapeLSE{}, StrideLSE{}),
    select<0>(SmemLayoutLSE{}),
    select<0>(TileShape_MNK{}),
    _1{})); // no mcast for LSE
  // Thread count of the S/dP tiled MMA.
  static constexpr int NumMmaThreads = size(TiledMmaSdP{});
  using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  using PipelineState = typename MainloopPipeline::PipelineState;
  // Set the bytes transferred in this TMA transaction (may involve multiple issues)
  // Q: one (kBlockM, kHeadDim) stage; K / V: the whole tile; LSE: one kBlockM-long stage.
  static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v<Element> / 8);
  static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(SmemLayoutK{}) * cutlass::sizeof_bits_v<Element> / 8);
  static constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size(SmemLayoutV{}) * cutlass::sizeof_bits_v<Element> / 8);
  static constexpr uint32_t TmaTransactionBytesLSE = static_cast<uint32_t>(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v<ElementAccum> / 8);
  // Shared-memory buffers used by this mainloop. Member order and per-member alignment determine
  // the shared-memory layout, so changes here affect the whole kernel's smem footprint.
  struct TensorStorage : cute::aligned_struct<1024> {
    // Single K and V tiles, loaded once per work tile (see load()).
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
    // It's important that smem_dqacc is aligned to 1024 bytes for the TMA, so that the 1st row
    // has no swizzle.
    // If the address is only 128 bytes aligned, it's possible that the 1st row has swizzle
    // and when we read it back in the postprocess kernel, the swizzle will not match.
    cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>, 1024> smem_dqacc;
    // kStages-deep buffers for Q, dO and dS.
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
    cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    // kStages-deep LSE and dPsum vectors (row stride rounded up to 32 elements, see SmemLayoutLSE).
    cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;
    cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;
  };
  // Sum of the byte sizes of the Q, dO, dS and dQaccum buffers (padding between members not included).
  static constexpr int SharedStorageQdOSize = sizeof(decltype((TensorStorage{}).smem_q)) + sizeof(decltype((TensorStorage{}).smem_do)) + sizeof(decltype((TensorStorage{}).smem_ds)) + sizeof(decltype((TensorStorage{}).smem_dqacc));
  // Host side kernel arguments
  // Converted to Params by to_underlying_arguments(). Several members are const, so callers
  // presumably use (positional) aggregate initialization — reordering members would silently
  // change meaning. TODO confirm against the launch code.
  struct Arguments {
    // Q: shape (seqlen, d, head, batch) per ShapeQKV.
    Element const* ptr_Q;
    ShapeQKV const shape_Q;
    StrideQKV const stride_Q;
    // K: same layout convention as Q; V reuses shape_K.
    Element const* ptr_K;
    ShapeQKV const shape_K;
    StrideQKV const stride_K;
    Element const* ptr_V;
    StrideQKV const stride_V;
    // dO reuses shape_Q.
    Element const* ptr_dO;
    StrideQKV const stride_dO;
    // dQ accumulator, written with a TMA reduce-add. Its shape is separate because the buffer is
    // viewed with a last dimension of ElemsPerRowTMA (see store_dq).
    ElementAccum* ptr_dQaccum;
    ShapeQKV const shape_dQaccum;
    StrideQKV const stride_dQaccum;
    // Per-row LSE (name suggests base-2 scaled — TODO confirm) with shape (seqlen, head, batch).
    float const* ptr_LSE_log2;
    ShapeLSE const shape_LSE;
    StrideLSE const stride_LSE_log2;
    // dPsum reuses shape_LSE.
    float const* ptr_dPsum;
    StrideLSE const stride_dPsum;
    // Params additionally stores softmax_scale * log2(e).
    float const softmax_scale;
    int num_batch;
    // Required (asserted non-null) when Deterministic; indexed by bidb * num_head + bidh.
    int* dq_semaphore;
    // Varlen only: cumulative sequence offsets per batch; nullptr means fixed-length batches.
    int const* cu_seqlens_q = nullptr;
    int const* cu_seqlens_k = nullptr;
    // NOTE(review): despite their names, to_underlying_arguments() forwards these positionally into
    // Params::seqused_q and Params::seqused_k respectively, so `seqused_k` here is used as the
    // per-batch *query* length and `seqused_v` as the per-batch *key* length. Callers must pass
    // (seqused_q, seqused_k) in that order — confirm; consider renaming.
    int const* seqused_k = nullptr;
    int const* seqused_v = nullptr;
  };
  // Device side kernel params
  // Built by to_underlying_arguments() with aggregate initialization, so member order here must
  // match the order of its return statement.
  struct Params {
    ShapeQKV const shape_Q;
    ShapeQKV const shape_K;
    ShapeQKV const shape_dQaccum;
    // Maps a query-head index to its K/V head index: divides by ceil_div(num_heads_q, num_heads_k).
    cutlass::FastDivmod qhead_per_khead_divmod;
    TMA_QdO tma_load_Q, tma_load_dO;
    TMA_K tma_load_K;
    TMA_V tma_load_V;
    // TMA reduce-add store of the dQ accumulator.
    TMA_add_dQ tma_add_dQ;
    TMA_LSE tma_load_LSE, tma_load_dPsum;
    float const* ptr_LSE_log2;
    ShapeLSE const shape_LSE;
    StrideLSE const stride_LSE_log2;
    float const* ptr_dPsum;
    StrideLSE const stride_dPsum;
    float const softmax_scale;
    // softmax_scale * log2(e), precomputed on the host.
    float const softmax_scale_log2;
    int num_batch;
    int* dq_semaphore;
    // Consulted only when Varlen; nullptr means every batch uses the padded extent from shape_Q / shape_K.
    int const* cu_seqlens_q = nullptr;
    int const* cu_seqlens_k = nullptr;
    // Optional per-batch lengths; used (when non-null) only if the matching cu_seqlens pointer is non-null.
    int const* seqused_q = nullptr;
    int const* seqused_k = nullptr;
  };
  274. static Params
  275. to_underlying_arguments(Arguments const& args) {
  276. Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
  277. TMA_QdO tma_load_Q = make_tma_copy(
  278. GmemTiledCopyQdO{},
  279. mQ,
  280. SmemLayoutQ{}(_, _, _0{}),
  281. select<0, 2>(TileShape_MNK{}),
  282. size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
  283. Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO);
  284. TMA_QdO tma_load_dO = make_tma_copy(
  285. GmemTiledCopyQdO{},
  286. mdO,
  287. SmemLayoutdO{}(_, _, _0{}),
  288. select<0, 2>(TileShape_MNK{}),
  289. size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
  290. Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
  291. TMA_K tma_load_K = make_tma_copy(
  292. GmemTiledCopyKV{},
  293. mK,
  294. SmemLayoutK{},
  295. select<1, 2>(TileShape_MNK{}),
  296. _1{}); // no mcast for KV
  297. Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V);
  298. TMA_V tma_load_V = make_tma_copy(
  299. GmemTiledCopyKV{},
  300. mV,
  301. SmemLayoutV{},
  302. select<1, 2>(TileShape_MNK{}),
  303. _1{}); // no mcast for KV
  304. Tensor mdQaccum = make_tensor(make_gmem_ptr(args.ptr_dQaccum), args.shape_dQaccum, args.stride_dQaccum);
  305. TMA_add_dQ tma_add_dQ = make_tma_copy(
  306. GmemTiledCopydQaccum{},
  307. mdQaccum,
  308. SmemLayoutdQaccumTMA{},
  309. TileShape_dQaccum{},
  310. _1{}); // no mcast for dQaccum
  311. Tensor mLSE = make_tensor(make_gmem_ptr(args.ptr_LSE_log2), args.shape_LSE, args.stride_LSE_log2);
  312. TMA_LSE tma_load_LSE = make_tma_copy(
  313. GmemTiledCopyLSE{},
  314. mLSE,
  315. select<0>(SmemLayoutLSE{}),
  316. select<0>(TileShape_MNK{}),
  317. _1{}); // no mcast for LSE
  318. Tensor mdPsum = make_tensor(make_gmem_ptr(args.ptr_dPsum), args.shape_LSE, args.stride_dPsum);
  319. TMA_LSE tma_load_dPsum = make_tma_copy(
  320. GmemTiledCopyLSE{},
  321. mdPsum,
  322. select<0>(SmemLayoutLSE{}),
  323. select<0>(TileShape_MNK{}),
  324. _1{}); // no mcast for dPsum
  325. if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }
  326. return {args.shape_Q, args.shape_K, args.shape_dQaccum,
  327. cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
  328. tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, tma_add_dQ, tma_load_LSE, tma_load_dPsum,
  329. args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
  330. args.softmax_scale, float(args.softmax_scale * M_LOG2E),
  331. args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k,
  332. args.seqused_k, args.seqused_v};
  333. }
  334. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  335. CUTLASS_DEVICE
  336. static void prefetch_tma_descriptors(Params const& params) {
  337. cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
  338. cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor());
  339. cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
  340. cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
  341. cute::prefetch_tma_descriptor(params.tma_load_LSE.get_tma_descriptor());
  342. cute::prefetch_tma_descriptor(params.tma_load_dPsum.get_tma_descriptor());
  343. cute::prefetch_tma_descriptor(params.tma_add_dQ.get_tma_descriptor());
  344. }
  345. CUTLASS_DEVICE
  346. int get_seqlen_q(Params const& params, int bidb) {
  347. if constexpr (!Varlen) {
  348. return get<0>(params.shape_Q);
  349. } else {
  350. return params.cu_seqlens_q == nullptr
  351. ? get<0>(params.shape_Q)
  352. : (params.seqused_q
  353. ? params.seqused_q[bidb]
  354. : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]
  355. );
  356. }
  357. }
  358. CUTLASS_DEVICE
  359. int get_seqlen_k(Params const& params, int bidb) {
  360. if constexpr (!Varlen) {
  361. return get<0>(params.shape_K);
  362. } else {
  363. return params.cu_seqlens_k == nullptr
  364. ? get<0>(params.shape_K)
  365. : (params.seqused_k
  366. ? params.seqused_k[bidb]
  367. : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]
  368. );
  369. }
  370. }
  371. CUTLASS_DEVICE
  372. int get_m_block_min(Params const& params, int n_block, int bidb) {
  373. if constexpr (Is_causal) {
  374. int const seqlen_q = get_seqlen_q(params, bidb);
  375. int const seqlen_k = get_seqlen_k(params, bidb);
  376. return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM);
  377. } else {
  378. return 0;
  379. }
  380. }
  381. template <typename SchedulerPrefetch, typename SharedStorage>
  382. CUTLASS_DEVICE void
  383. load(Params const& params,
  384. MainloopPipeline pipeline_q,
  385. MainloopPipeline pipeline_do,
  386. PipelineState& smem_pipe_write,
  387. SharedStorage &shared_storage,
  388. SchedulerPrefetch const& scheduler_prefetch,
  389. cute::tuple<int32_t, int32_t, int32_t> block_coord,
  390. int work_idx
  391. ) {
  392. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQ{});
  393. Tensor sdO = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdO{});
  394. Tensor sK = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutK{});
  395. Tensor sV = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_v.data()), SmemLayoutV{});
  396. Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_lse.data()), SmemLayoutLSE{});
  397. Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dpsum.data()), SmemLayoutLSE{});
  398. auto [n_block, bidh, bidb] = block_coord;
  399. int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
  400. // Prepare the TMA loads
  401. uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
  402. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  403. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  404. bool const is_varlen_q = Varlen && params.cu_seqlens_q != nullptr;
  405. bool const is_varlen_k = Varlen && params.cu_seqlens_k != nullptr;
  406. Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
  407. Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
  408. Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
  409. Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
  410. Tensor mLSE = params.tma_load_LSE.get_tma_tensor(params.shape_LSE)(_, bidh, !is_varlen_q ? bidb : 0);
  411. Tensor mdPsum = params.tma_load_dPsum.get_tma_tensor(params.shape_LSE)(_, bidh, !is_varlen_q ? bidb : 0);
  412. int const offset_q = !is_varlen_q ? 0 : params.cu_seqlens_q[bidb];
  413. int const offset_k = !is_varlen_k ? 0 : params.cu_seqlens_k[bidb];
  414. int const offset_padded = !is_varlen_q ? 0 : (params.cu_seqlens_q[bidb] + bidb * 128) / 128 * 128;
  415. Tensor gQ = local_tile(domain_offset(make_coord(offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
  416. Tensor gdO = local_tile(domain_offset(make_coord(offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _)
  417. Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
  418. Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K)
  419. Tensor gLSE = local_tile(domain_offset(make_coord(offset_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _)
  420. Tensor gdPsum = local_tile(domain_offset(make_coord(offset_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _)
  421. Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{}));
  422. Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{}));
  423. Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{}));
  424. Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{}));
  425. auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout<ClusterShape>{},
  426. group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE)
  427. auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout<ClusterShape>{},
  428. group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE)
  429. auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{},
  430. group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA)
  431. auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{},
  432. group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA)
  433. auto [tLSEgLSE, tLSEsLSE] = tma_partition(params.tma_load_LSE, _0{}, Layout<_1>{},
  434. sLSE, gLSE); // (TMA, k), (TMA, PIPE)
  435. auto [tLSEgdPsum, tLSEsdPsum] = tma_partition(params.tma_load_dPsum, _0{}, Layout<_1>{},
  436. sdPsum, gdPsum); // (TMA, k), (TMA, PIPE)
  437. uint16_t mcast_mask_qdo = 0;
  438. if constexpr (cute::is_same_v<GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
  439. auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
  440. for (int n = 0; n < size<1>(block_layout); ++n) {
  441. mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{}));
  442. }
  443. }
  444. int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{}));
  445. int m_block_min = get_m_block_min(params, n_block, bidb);
  446. int m_block = m_block_min;
  447. int lane_predicate = cute::elect_one_sync();
  448. // // Wait for the MMA warpgroups to say that smem_q is ready
  449. // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::QueryEmpty) /*id*/);
  450. if (lane_predicate) {
  451. // Copy K tile and V tile from GMEM to SMEM.
  452. shared_storage.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV);
  453. copy(params.tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK);
  454. copy(params.tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV);
  455. pipeline_q.producer_acquire(smem_pipe_write);
  456. copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index()));
  457. copy(params.tma_load_LSE.with(*pipeline_q.producer_get_barrier(smem_pipe_write), 0), tLSEgLSE(_, m_block), tLSEsLSE(_, smem_pipe_write.index()));
  458. #pragma unroll 2
  459. for (; m_block < m_block_max - 1; ++m_block) {
  460. pipeline_do.producer_acquire(smem_pipe_write);
  461. copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write.index()));
  462. copy(params.tma_load_dPsum.with(*pipeline_do.producer_get_barrier(smem_pipe_write), 0), tLSEgdPsum(_, m_block), tLSEsdPsum(_, smem_pipe_write.index()));
  463. ++smem_pipe_write;
  464. pipeline_q.producer_acquire(smem_pipe_write);
  465. copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index()));
  466. copy(params.tma_load_LSE.with(*pipeline_q.producer_get_barrier(smem_pipe_write), 0), tLSEgLSE(_, m_block + 1), tLSEsLSE(_, smem_pipe_write.index()));
  467. }
  468. }
  469. scheduler_prefetch();
  470. if (lane_predicate) {
  471. pipeline_do.producer_acquire(smem_pipe_write);
  472. copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write.index()));
  473. copy(params.tma_load_dPsum.with(*pipeline_do.producer_get_barrier(smem_pipe_write), 0), tLSEgdPsum(_, m_block), tLSEsdPsum(_, smem_pipe_write.index()));
  474. ++smem_pipe_write;
  475. }
  476. }
  477. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  478. CUTLASS_DEVICE void
  479. load_tail(MainloopPipeline pipeline_q, MainloopPipeline pipeline_do,
  480. PipelineState& smem_pipe_write) {
  481. // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write
  482. PipelineState smem_pipe_write_do = smem_pipe_write;
  483. int lane_predicate = cute::elect_one_sync();
  484. // Issue the epilogue waits
  485. if (lane_predicate) {
  486. /* This helps avoid early exit of blocks in Cluster
  487. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  488. * then would just be acquired since the phase was still inverted from make_producer_start_state
  489. */
  490. pipeline_q.producer_tail(smem_pipe_write);
  491. pipeline_do.producer_tail(smem_pipe_write_do);
  492. }
  493. }
  494. template <typename SharedStorage>
  495. CUTLASS_DEVICE void
  496. store_dq(Params const& params,
  497. SharedStorage &shared_storage,
  498. cute::tuple<int32_t, int32_t, int32_t> block_coord
  499. ) {
  500. Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMA{});
  501. Tensor sdQnoswizzle = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMANoSwizzle{});
  502. auto [n_block, bidh, bidb] = block_coord;
  503. bool const is_varlen_q = Varlen && params.cu_seqlens_q != nullptr;
  504. // We reshaped dQaccum to have last dimension 32, so the offset needs to be multiplied by kHeadDim / 32
  505. int const offset_padded = !is_varlen_q ? 0 : ((params.cu_seqlens_q[bidb] + bidb * 128) / 128 * 128) * (kHeadDim / ElemsPerRowTMA);
  506. // Prepare the TMA loads
  507. Tensor mdQaccum = params.tma_add_dQ.get_tma_tensor(params.shape_dQaccum)(_, _, bidh, !is_varlen_q ? bidb : 0);
  508. Tensor gdQaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdQaccum), TileShape_dQaccum{}, make_coord(_, _0{})); // (M, K, _)
  509. auto block_tma_dQ = params.tma_add_dQ.get_slice(_0{});
  510. Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K)
  511. Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K)
  512. int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{}));
  513. int m_block_min = get_m_block_min(params, n_block, bidb);
  514. int m_block = m_block_min;
  515. int const num_batch = params.num_batch;
  516. int const num_head = get<2>(params.shape_Q);
  517. int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh;
  518. using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
  519. int lane_predicate = cute::elect_one_sync();
  520. #pragma unroll 2
  521. for (; m_block < m_block_max; ++m_block) {
  522. if constexpr (Deterministic) {
  523. Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block);
  524. }
  525. cutlass::arch::NamedBarrier::sync(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQFull) /*id*/); // sdQ full, to be written to gmem
  526. if (lane_predicate) {
  527. cute::copy(params.tma_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block));
  528. tma_store_arrive();
  529. }
  530. tma_store_wait<0>();
  531. if constexpr (Deterministic) {
  532. Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
  533. }
  534. cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
  535. }
  536. }
  537. CUTLASS_DEVICE void
  538. mma_init() {
  539. // // Tell producer (warp 0) that smem_q is ready
  540. // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::QueryEmpty) /*id*/);
  541. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  542. if (cutlass::canonical_warp_group_idx() == 1 && warp_idx_in_warpgroup == 0) {
  543. cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
  544. }
  545. }
  546. template <typename SharedStorage, typename FrgTensordKV>
  547. CUTLASS_DEVICE void
  548. mma(Params const& params,
  549. MainloopPipeline pipeline_q,
  550. MainloopPipeline pipeline_do,
  551. PipelineState& smem_pipe_read,
  552. FrgTensordKV& tdKrdK,
  553. FrgTensordKV& tdVrdV,
  554. int thread_idx,
  555. int work_idx,
  556. cute::tuple<int32_t, int32_t, int32_t> block_coord,
  557. SharedStorage& shared_storage
  558. ) {
  559. static_assert(is_rmem<FrgTensordKV>::value, "dK and dV tensor must be rmem resident.");
  560. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQ{});
  561. Tensor sdO = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdO{});
  562. Tensor sK = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutK{});
  563. Tensor sV = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_v.data()), SmemLayoutV{});
  564. Tensor sQt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQt{});
  565. Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdOt{});
  566. Tensor sKt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutKt{});
  567. Tensor sdS = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_ds.data()), SmemLayoutdS{});
  568. Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_ds.data()), SmemLayoutdSt{});
  569. Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});
  570. Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_lse.data()), SmemLayoutLSEMma{});
  571. Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{});
  572. static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and
  573. stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and
  574. size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and
  575. size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup,
  576. "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
  577. constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
  578. Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),
  579. make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
  580. Layout warp_group_thread_layout_dq = make_layout(make_shape(Int<NumdQWarpGroups>{}),
  581. make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
  582. int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
  583. TiledMmaSdP tiled_mma_SdP;
  584. TiledMmadKV tiled_mma_dKV;
  585. TiledMmadQ tiled_mma_dQ;
  586. static_assert(!dKV_swapAB);
  587. auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx));
  588. auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);
  589. auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx));
  590. auto wg_mma_dQ = tiled_mma_dQ.get_slice(!Varlen ? warp_group_thread_layout_dq(NumdQWarpGroups == 2 ? warp_group_idx : 0) : thread_idx);
  591. // auto wg_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx);
  592. auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP);
  593. auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx);
  594. Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  595. R2STiledCopydQaccum r2s_tiled_copy_dQaccum;
  596. // auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
  597. auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(NumdQWarpGroups == 2 ? thread_idx : thread_idx % cutlass::NumThreadsPerWarpGroup);
  598. Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ);
  599. // Allocate "fragments/descriptors"
  600. Tensor tSrQ = wg_mma_SdP.partition_fragment_B(sQ);
  601. Tensor tSrK = wg_mma_SdP.partition_fragment_A(sK);
  602. Tensor tdPrdO = wg_mma_SdP.partition_fragment_B(sdO);
  603. Tensor tdPrV = wg_mma_SdP.partition_fragment_A(sV);
  604. Tensor tdVrdO = wg_mma_dKV.partition_fragment_B(sdOt);
  605. Tensor tdKrQ = wg_mma_dKV.partition_fragment_B(sQt);
  606. int n_block = get<0>(block_coord);
  607. int bidh = get<1>(block_coord);
  608. int bidb = get<2>(block_coord);
  609. int const seqlen_q = get_seqlen_q(params, bidb);
  610. int const seqlen_k = get_seqlen_k(params, bidb);
  611. int m_block_max = cute::ceil_div(get_seqlen_q(params, bidb), get<0>(TileShape_MNK{}));
  612. int m_block_min = get_m_block_min(params, n_block, bidb);
  613. int m_block = m_block_min;
  614. // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the row indices.
  615. Tensor tLSEsLSE = thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _0{}, _); // (2, V, PIPE)
  616. Tensor tLSEsdPsum = thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _0{}, _);
  617. clear(tdKrdK);
  618. clear(tdVrdV);
  619. // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero;
  620. cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_KV.try_wait(work_idx % 2));
  621. if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_KV.wait(work_idx % 2); }
  622. auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
  623. auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  624. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  625. };
  626. auto compute_dQ = [&]() {
  627. static_assert(!Mma_dQ_is_RS);
  628. // SMEM fence to make sure sP is written before it's read by WGMMA
  629. cutlass::arch::fence_view_async_shared();
  630. cutlass::arch::NamedBarrier::sync(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
  631. Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
  632. if constexpr (!dQ_swapAB) {
  633. Tensor tdQrdS = wg_mma_dQ.partition_fragment_A(sdS);
  634. Tensor tdQrK = wg_mma_dQ.partition_fragment_B(sKt);
  635. flash::gemm</*zero_init=*/true, /*wg_wait=*/1>(tiled_mma_dQ, tdQrdS(_, _, _, smem_pipe_read.index()), tdQrK, tdQrdQ);
  636. } else {
  637. Tensor tdQrdS = wg_mma_dQ.partition_fragment_B(sdS);
  638. Tensor tdQrK = wg_mma_dQ.partition_fragment_A(sKt);
  639. flash::gemm</*zero_init=*/true, /*wg_wait=*/1>(tiled_mma_dQ, tdQrK, tdQrdS(_, _, _, smem_pipe_read.index()), tdQrdQ);
  640. }
  641. pipeline_q.consumer_release(smem_pipe_read); // release Q
  642. warpgroup_wait<0>();
  643. Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
  644. cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
  645. cutlass::arch::fence_view_async_shared();
  646. cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQFull) /*id*/); // sdQ full, to be written to gmem
  647. };
  648. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
  649. // this helps quite a bit to not have to do causal masking for most of the iterations.
  650. if constexpr (Is_causal) {
  651. static constexpr int n_masking_steps = cute::ceil_div(kBlockN, kBlockM) + 1;
  652. CUTLASS_PRAGMA_NO_UNROLL
  653. for (; m_block < std::min(m_block_max, m_block_min + n_masking_steps); ++m_block) {
  654. Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
  655. pipeline_q.consumer_wait(smem_pipe_read);
  656. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tSrK, tSrQ(_, _, _, smem_pipe_read.index()), tSrS);
  657. Tensor tLSErLSE = make_fragment_like(tLSEsLSE(_, _, _0{}));
  658. cute::copy(tLSEsLSE(_, _, smem_pipe_read.index()), tLSErLSE);
  659. Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
  660. pipeline_do.consumer_wait(smem_pipe_read);
  661. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read.index()), tdPrdP);
  662. warpgroup_wait<1>();
  663. Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{}));
  664. Tensor taccScS = thread_mma_SdP.partition_C(cS);
  665. int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
  666. #pragma unroll
  667. for (int i = 0; i < size(tSrS); ++i) {
  668. if (int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + causal_row_offset,
  669. seqlen_k - n_block * kBlockN)) {
  670. tSrS(i) = -INFINITY;
  671. }
  672. }
  673. // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  674. Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
  675. flash::scale_apply_exp2</*Scale_max=*/false, /*Check_inf=*/false>(scores, group_modes<0, 2>(tLSErLSE), params.softmax_scale_log2);
  676. Tensor tLSErdPsum = make_fragment_like(tLSEsdPsum(_, _, _0{}));
  677. cute::copy(tLSEsdPsum(_, _, smem_pipe_read.index()), tLSErdPsum);
  678. // Convert scores from fp32 to fp16/bf16
  679. Tensor rP = flash::convert_type<Element>(tSrS);
  680. warpgroup_wait<0>();
  681. // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  682. Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
  683. for (int mi = 0; mi < size<0>(dS); ++mi) {
  684. #pragma unroll
  685. for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - tLSErdPsum(mi)); }
  686. }
  687. Tensor rdS = flash::convert_type<Element>(tdPrdP);
  688. // Because of double buffering on dS, we don't need to sync here.
  689. // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
  690. // But because both WGs have to sync at the end of the loop and double buffering, this race condition
  691. // is not possible.
  692. Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
  693. cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, smem_pipe_read.index()));
  694. Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
  695. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read.index()), tdVrdV);
  696. Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
  697. flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  698. pipeline_do.consumer_release(smem_pipe_read); // release dO
  699. compute_dQ();
  700. ++smem_pipe_read;
  701. }
  702. }
  703. CUTLASS_PRAGMA_NO_UNROLL
  704. for (; m_block < m_block_max; ++m_block) {
  705. Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
  706. pipeline_q.consumer_wait(smem_pipe_read);
  707. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tSrK, tSrQ(_, _, _, smem_pipe_read.index()), tSrS);
  708. Tensor tLSErLSE = make_fragment_like(tLSEsLSE(_, _, _0{}));
  709. cute::copy(tLSEsLSE(_, _, smem_pipe_read.index()), tLSErLSE);
  710. Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
  711. pipeline_do.consumer_wait(smem_pipe_read);
  712. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read.index()), tdPrdP);
  713. warpgroup_wait<1>();
  714. Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{}));
  715. Tensor taccScS = thread_mma_SdP.partition_C(cS);
  716. #pragma unroll
  717. for (int i = 0; i < size(tSrS); ++i) {
  718. if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
  719. }
  720. // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  721. Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
  722. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tLSErLSE); }
  723. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(scores); }
  724. flash::scale_apply_exp2</*Scale_max=*/false, /*Check_inf=*/false>(scores, group_modes<0, 2>(tLSErLSE), params.softmax_scale_log2);
  725. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(scores); }
  726. Tensor tLSErdPsum = make_fragment_like(tLSEsdPsum(_, _, _0{}));
  727. cute::copy(tLSEsdPsum(_, _, smem_pipe_read.index()), tLSErdPsum);
  728. // Convert scores from fp32 to fp16/bf16
  729. Tensor rP = flash::convert_type<Element>(tSrS);
  730. warpgroup_wait<0>();
  731. // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  732. Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
  733. #pragma unroll
  734. for (int mi = 0; mi < size<0>(dS); ++mi) {
  735. #pragma unroll
  736. for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - tLSErdPsum(mi)); }
  737. }
  738. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dS); }
  739. Tensor rdS = flash::convert_type<Element>(tdPrdP);
  740. Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
  741. cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, smem_pipe_read.index()));
  742. Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
  743. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read.index()), tdVrdV);
  744. Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
  745. flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  746. pipeline_do.consumer_release(smem_pipe_read); // release dO
  747. compute_dQ();
  748. ++smem_pipe_read;
  749. }
  750. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }
  751. #pragma unroll
  752. for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }
  753. }
  754. };
  755. } // namespace flash