// mainloop_bwd_sm90_tma_gmma_ws.hpp
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cutlass/cutlass.h>
  6. #include <cutlass/array.h>
  7. #include <cutlass/numeric_types.h>
  8. #include <cutlass/numeric_conversion.h>
  9. #include <cutlass/barrier.h>
  10. #include "cutlass/pipeline/pipeline.hpp"
  11. #include "cute/tensor.hpp"
  12. #include "cutlass/gemm/collective/collective_builder.hpp"
  13. #include "named_barrier.hpp"
  14. #include "softmax.h"
  15. #include "utils.h"
  16. namespace flash {
  17. using namespace cute;
// Collective mainloop for the FlashAttention-3 backward pass on SM90 (Hopper),
// using TMA for gmem<->smem transfers and warpgroup GMMA for the matmuls.
// Template knobs select the masking mode (causal/local), variable-length
// batches, operand (A/B) swaps for the dKV/dQ GEMMs, and MMA atom layouts.
template <int Stages, class ClusterShape_, class TileShape_MNK_, class Element_, class ElementAccum_, class ArchTag_,
bool Is_causal_, bool Is_local_, bool Varlen_, bool Deterministic,
bool dKV_swapAB_, bool dQ_swapAB_,
int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1>
struct CollectiveMainloopBwd {
// Depth of the Q/dO smem producer-consumer pipelines.
static constexpr int kStages = Stages;
using ClusterShape = ClusterShape_;
using TileShape_MNK = TileShape_MNK_;
using Element = Element_;
using ElementAccum = ElementAccum_;
using ArchTag = ArchTag_;
static constexpr bool Is_causal = Is_causal_;
static constexpr bool Is_local = Is_local_;
static constexpr bool Varlen = Varlen_;
// The S/dP GEMMs always run with A and B swapped (hard-coded on).
static constexpr bool SdP_swapAB = true;
static constexpr bool dKV_swapAB = dKV_swapAB_;
static constexpr bool dQ_swapAB = dQ_swapAB_;
static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
// Tile extents: M = query rows, N = key rows, K = head dimension.
static constexpr int kBlockM = get<0>(TileShape_MNK{});
static constexpr int kBlockN = get<1>(TileShape_MNK{});
static constexpr int kHeadDim = get<2>(TileShape_MNK{});
// dQ accumulation uses 2 warpgroups (2 * 128 = 256 threads).
static constexpr int NumdQWarpGroups = 2;
static constexpr int kNThreadsdQ = NumdQWarpGroups * cutlass::NumThreadsPerWarpGroup;
// Hopper-only kernel; cluster may only extend along the N (mode-1) dimension.
static_assert(ArchTag::kMinComputeCapability >= 90);
static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1);
static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
// --- Tiled MMA configuration -------------------------------------------------
// For each GEMM (S/dP, dKV, dQ) we pick a per-atom tile shape and an atom
// layout that splits the work across warpgroups (the "2 / AtomLayout..."
// factors distribute the 2 MMA warpgroups over M or N). The swapAB variants
// exchange which operand plays A vs B.
using TileShapeAtomSdP = std::conditional_t<
!SdP_swapAB,
Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
Shape<Int<kBlockN>, Int<kBlockM / AtomLayoutMSdP>, Int<kHeadDim>>
>;
using AtomLayoutSdP = std::conditional_t<
!SdP_swapAB,
Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
>;
// S/dP GEMM: both operands sourced from shared memory (SS GMMA).
using TiledMmaSdP = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
AtomLayoutSdP{}));
using TileShapeAtomdKV = std::conditional_t<
!dKV_swapAB,
Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
Shape<Int<kHeadDim>, Int<kBlockN / AtomLayoutNdKV>, Int<kBlockM>>
>;
using AtomLayoutdKV = std::conditional_t<
!dKV_swapAB,
Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
>;
// dKV GEMM: SS when !SdP_swapAB, otherwise A comes from registers (RS),
// with operand majors chosen to match the smem layouts below.
using TiledMmadKV = decltype(cute::make_tiled_mma(
std::conditional_t<
!SdP_swapAB,
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
>{},
AtomLayoutdKV{}));
using TileShapeAtomdQ = std::conditional_t<
!dQ_swapAB,
Shape<Int<kBlockM>, Int<kHeadDim / (NumdQWarpGroups / AtomLayoutMdQ)>, Int<kBlockN>>,
Shape<Int<kHeadDim>, Int<kBlockM / AtomLayoutMdQ>, Int<kBlockN>>
>;
using AtomLayoutdQ = std::conditional_t<
!dQ_swapAB,
Layout<Shape<Int<AtomLayoutMdQ>, Int<NumdQWarpGroups / AtomLayoutMdQ>, _1>>,
Layout<Shape<Int<NumdQWarpGroups / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
>;
static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
// dQ GEMM: RS only when Mma_dQ_is_RS (see its definition above), SS otherwise.
using TiledMmadQ = decltype(cute::make_tiled_mma(
std::conditional_t<
!dQ_swapAB,
std::conditional_t<
Mma_dQ_is_RS,
decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
>,
decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
>{},
AtomLayoutdQ{}));
// --- Shared-memory layouts ---------------------------------------------------
// Q and dO are kStages-deep (pipelined); K/V/P/dS are single-stage tiles.
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
Int<kBlockM>, Int<dKV_swapAB ? kHeadDim : kHeadDim / (2 / AtomLayoutNdKV)>>());
using SmemLayoutQ =
decltype(tile_to_shape(SmemLayoutAtomQ{},
make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutdO = SmemLayoutQ;
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
Int<kBlockN>, Int<dQ_swapAB ? kHeadDim : kHeadDim / (NumdQWarpGroups / AtomLayoutMdQ)>>());
using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, make_shape(Int<kBlockM>{}, Int<kBlockN>{}, Int<kStages>{})));
// Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80
using SmemLayoutLSE = cute::Layout<cute::Shape<Int<kBlockM>, Int<kStages>>, cute::Stride<_1, Int<cute::round_up(kBlockM, 32)>>>;
// LSE broadcast view for the MMA: stride-0 along N replicates each row value.
using SmemLayoutLSEMma = cute::Layout<cute::Shape<Int<kBlockN>, Int<kBlockM>, Int<kStages>>, cute::Stride<_0, _1, Int<cute::round_up(kBlockM, 32)>>>;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutQt =
decltype(cute::composition(SmemLayoutQ{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
using SmemLayoutdOt =
decltype(cute::composition(SmemLayoutdO{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
using SmemLayoutKt =
decltype(cute::composition(SmemLayoutK{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
make_stride(Int<kBlockN>{}, _1{}))));
using SmemLayoutPt =
decltype(cute::composition(SmemLayoutP{},
make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
make_stride(Int<kBlockM>{}, _1{}))));
using SmemLayoutdSt =
decltype(cute::composition(SmemLayoutdS{},
make_layout(make_shape(Int<kBlockN>{}, Int<kBlockM>{}, Int<kStages>{}),
make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kBlockN>{}))));
// Thread layout, 256 threads per row
using R2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreadsdQ>>, Stride<_1>>;
using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, R2SLayoutAtomdQaccum{},
Layout<Shape < _4>>{})); // Val layout, 4 vals per store
// dQ accumulator smem is treated as a flat 1-D buffer of kBlockM*kHeadDim floats.
using SmemLayoutdQaccum = Layout<Shape<Int<kBlockM * kHeadDim>>, Stride<_1>>;
// We want dQaccum smem to have last dimension 32, so that we only need to do 1 TMA instruction.
// The layout Layout_K_SW128_Atom<ElementAccum> has 32 elements per row.
// // TMA limit is that each dimension in smem must be <= 256.
// static constexpr int ElemsPerRowTMA = (kBlockM * kHeadDim) / 32 <= 256 ? 32 : 64;
static constexpr int ElemsPerRowTMA = 32; // If we change this, we'll also need to change the dQ shape in host.
static_assert((kBlockM * kHeadDim) % ElemsPerRowTMA == 0);
using TileShape_dQaccum = cute::Shape<Int<(kBlockM * kHeadDim) / ElemsPerRowTMA>, Int<ElemsPerRowTMA>>;
// using TileShape_dQaccum = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
using SmemLayoutdQaccumTMA =
decltype(tile_to_shape(GMMA::Layout_K_SW128_Atom<ElementAccum>{}, TileShape_dQaccum{}));
using SmemLayoutdQaccumTMANoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutdQaccumTMA{}));
// STSM atoms for writing P/dS (and dKV) from registers to smem; the _T
// variant is used when the corresponding GEMM runs with swapped operands.
using SmemCopyAtomPdS = Copy_Atom<
std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
Element>;
using SmemCopyAtomdKV = Copy_Atom<
std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
Element>;
// --- Global-memory (TMA) copy types ------------------------------------------
// Q/dO may use TMA multicast when the cluster spans more than one block in N;
// K/V use plain TMA loads; dQaccum is written back with TMA reduce-add.
using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{})));
using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
using GmemTiledCopydQaccum = cute::SM90_TMA_REDUCE_ADD;
using GmemTiledCopyLSE = cute::SM90_TMA_LOAD;
using ShapeQKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
using StrideQKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
// TMA descriptor types, deduced from dummy (nullptr) tensors with the same
// shapes/strides the real descriptors will use in to_underlying_arguments().
using TMA_QdO = decltype(make_tma_copy(
GmemTiledCopyQdO{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
take<0, 2>(SmemLayoutQ{}),
select<0, 2>(TileShape_MNK{}),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
using TMA_K = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
SmemLayoutK{},
select<1, 2>(TileShape_MNK{}),
_1{})); // no mcast for KV
using TMA_V = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), ShapeQKV{}, StrideQKV{}),
SmemLayoutV{},
select<1, 2>(TileShape_MNK{}),
_1{})); // no mcast for KV
using TMA_add_dQ = decltype(make_tma_copy(
GmemTiledCopydQaccum{},
make_tensor(make_gmem_ptr(static_cast<ElementAccum*>(nullptr)), ShapeQKV{}, StrideQKV{}),
SmemLayoutdQaccumTMA{},
TileShape_dQaccum{},
_1{})); // no mcast for dQ
using TMA_LSE = decltype(make_tma_copy(
GmemTiledCopyLSE{},
make_tensor(make_gmem_ptr(static_cast<ElementAccum const*>(nullptr)), ShapeLSE{}, StrideLSE{}),
select<0>(SmemLayoutLSE{}),
select<0>(TileShape_MNK{}),
_1{})); // no mcast for LSE
static constexpr int NumMmaThreads = size(TiledMmaSdP{});
// Async TMA-fed pipeline shared by the Q and dO streams (one instance each).
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename MainloopPipeline::PipelineState;
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(SmemLayoutK{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesV = static_cast<uint32_t>(size(SmemLayoutV{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesLSE = static_cast<uint32_t>(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v<ElementAccum> / 8);
// Shared-memory storage for the mainloop. Member order matters: smem_q,
// smem_do, smem_ds and smem_dqacc (summed in SharedStorageQdOSize below) are
// presumably reused by a later phase -- TODO confirm against the epilogue.
struct TensorStorage : cute::aligned_struct<1024> {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
// It's important that smem_dqacc is aligned to 1024 bytes for the TMA, so that the 1st row
// has no swizzle.
// If the address is only 128 bytes aligned, it's possible that the 1st row has swizzle
// and when we read it back in the postprocess kernel, the swizzle will not match.
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>, 1024> smem_dqacc;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_lse;
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutLSE>, 128> smem_dpsum;
};
// Total bytes of the Q/dO/dS/dQaccum region (excludes K, V, LSE, dPsum).
static constexpr int SharedStorageQdOSize = sizeof(decltype((TensorStorage{}).smem_q)) + sizeof(decltype((TensorStorage{}).smem_do)) + sizeof(decltype((TensorStorage{}).smem_ds)) + sizeof(decltype((TensorStorage{}).smem_dqacc));
// Host side kernel arguments
struct Arguments {
Element const* ptr_Q;
ShapeQKV const shape_Q;      // (seqlen_q, d, num_heads_q, batch)
StrideQKV const stride_Q;
Element const* ptr_K;
ShapeQKV const shape_K;      // (seqlen_k, d, num_heads_k, batch); also used for V
StrideQKV const stride_K;
Element const* ptr_V;
StrideQKV const stride_V;
Element const* ptr_dO;       // upstream gradient, same shape as Q
StrideQKV const stride_dO;
ElementAccum* ptr_dQaccum;   // fp32 accumulator written via TMA reduce-add
ShapeQKV const shape_dQaccum;
StrideQKV const stride_dQaccum;
float const* ptr_LSE_log2;   // log-sum-exp from the forward pass, in log2 units
ShapeLSE const shape_LSE;
StrideLSE const stride_LSE_log2;
float const* ptr_dPsum;      // rowwise sum of dO * O (softmax backward term)
StrideLSE const stride_dPsum;
float const softmax_scale;
int num_batch;
int* dq_semaphore;           // required when Deterministic (asserted on host)
int const* cu_seqlens_q = nullptr;  // varlen: cumulative Q sequence offsets
int const* cu_seqlens_k = nullptr;  // varlen: cumulative K sequence offsets
// NOTE(review): naming is inconsistent with Params -- to_underlying_arguments()
// maps seqused_k -> Params::seqused_q and seqused_v -> Params::seqused_k
// positionally. So these presumably hold the used Q- and K-lengths
// respectively despite their names; confirm against the host-side caller.
int const* seqused_k = nullptr;
int const* seqused_v = nullptr;
int window_size_left;        // local-attention window bounds (Is_local)
int window_size_right;
};
// Device side kernel params
struct Params {
ShapeQKV const shape_Q;
ShapeQKV const shape_K;
ShapeQKV const shape_dQaccum;
// Precomputed divisor for num_heads_q / num_heads_k (GQA/MQA head mapping).
cutlass::FastDivmod qhead_per_khead_divmod;
// Prebuilt TMA descriptors (constructed on host in to_underlying_arguments).
TMA_QdO tma_load_Q, tma_load_dO;
TMA_K tma_load_K;
TMA_V tma_load_V;
TMA_add_dQ tma_add_dQ;
TMA_LSE tma_load_LSE, tma_load_dPsum;
float const* ptr_LSE_log2;
ShapeLSE const shape_LSE;
StrideLSE const stride_LSE_log2;
float const* ptr_dPsum;
StrideLSE const stride_dPsum;
float const softmax_scale;
// softmax_scale * M_LOG2E, so exp() can be computed as exp2().
float const softmax_scale_log2;
int num_batch;
int* dq_semaphore;
int const* cu_seqlens_q = nullptr;
int const* cu_seqlens_k = nullptr;
// Filled positionally from Arguments::seqused_k / seqused_v (see the
// NOTE in Arguments about the naming mismatch).
int const* seqused_q = nullptr;
int const* seqused_k = nullptr;
int window_size_left;
int window_size_right;
};
// Convert host-facing Arguments into device Params: build all TMA descriptors
// from the real gmem tensors and precompute derived scalars. The aggregate
// return at the end is positional -- its order must match the Params member
// declaration order exactly.
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
TMA_QdO tma_load_Q = make_tma_copy(
GmemTiledCopyQdO{},
mQ,
SmemLayoutQ{}(_, _, _0{}),  // one pipeline stage of the smem layout
select<0, 2>(TileShape_MNK{}),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO);
TMA_QdO tma_load_dO = make_tma_copy(
GmemTiledCopyQdO{},
mdO,
SmemLayoutdO{}(_, _, _0{}),
select<0, 2>(TileShape_MNK{}),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
TMA_K tma_load_K = make_tma_copy(
GmemTiledCopyKV{},
mK,
SmemLayoutK{},
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for KV
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V);
TMA_V tma_load_V = make_tma_copy(
GmemTiledCopyKV{},
mV,
SmemLayoutV{},
select<1, 2>(TileShape_MNK{}),
_1{}); // no mcast for KV
Tensor mdQaccum = make_tensor(make_gmem_ptr(args.ptr_dQaccum), args.shape_dQaccum, args.stride_dQaccum);
TMA_add_dQ tma_add_dQ = make_tma_copy(
GmemTiledCopydQaccum{},
mdQaccum,
SmemLayoutdQaccumTMA{},
TileShape_dQaccum{},
_1{}); // no mcast for dQaccum
Tensor mLSE = make_tensor(make_gmem_ptr(args.ptr_LSE_log2), args.shape_LSE, args.stride_LSE_log2);
TMA_LSE tma_load_LSE = make_tma_copy(
GmemTiledCopyLSE{},
mLSE,
select<0>(SmemLayoutLSE{}),  // single stage of the LSE smem layout
select<0>(TileShape_MNK{}),
_1{}); // no mcast for LSE
Tensor mdPsum = make_tensor(make_gmem_ptr(args.ptr_dPsum), args.shape_LSE, args.stride_dPsum);
TMA_LSE tma_load_dPsum = make_tma_copy(
GmemTiledCopyLSE{},
mdPsum,
select<0>(SmemLayoutLSE{}),
select<0>(TileShape_MNK{}),
_1{}); // no mcast for dPsum
// Deterministic dQ accumulation serializes via the semaphore, so it must exist.
if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); }
// qhead_per_khead = ceil(num_heads_q / num_heads_k); softmax_scale_log2
// pre-multiplies by M_LOG2E so the kernel can use exp2.
// NOTE(review): args.seqused_k / args.seqused_v land in Params::seqused_q /
// Params::seqused_k positionally (see the naming note on Arguments).
return {args.shape_Q, args.shape_K, args.shape_dQaccum,
cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, tma_add_dQ, tma_load_LSE, tma_load_dPsum,
args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
args.softmax_scale, float(args.softmax_scale * M_LOG2E),
args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k,
args.seqused_k, args.seqused_v, args.window_size_left, args.window_size_right};
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
/// Prefetches every TMA descriptor this mainloop uses (loads and the dQ
/// reduce-add) so the first real TMA issue does not stall fetching them.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_LSE.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_dPsum.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_add_dQ.get_tma_descriptor());
}
  350. CUTLASS_DEVICE
  351. int get_seqlen_q(Params const& params, int bidb) {
  352. if constexpr (!Varlen) {
  353. return get<0>(params.shape_Q);
  354. } else {
  355. return params.cu_seqlens_q == nullptr
  356. ? get<0>(params.shape_Q)
  357. : (params.seqused_q
  358. ? params.seqused_q[bidb]
  359. : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]
  360. );
  361. }
  362. }
  363. CUTLASS_DEVICE
  364. int get_seqlen_k(Params const& params, int bidb) {
  365. if constexpr (!Varlen) {
  366. return get<0>(params.shape_K);
  367. } else {
  368. return params.cu_seqlens_k == nullptr
  369. ? get<0>(params.shape_K)
  370. : (params.seqused_k
  371. ? params.seqused_k[bidb]
  372. : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]
  373. );
  374. }
  375. }
  376. CUTLASS_DEVICE
  377. int get_m_block_min(Params const& params, int n_block, int bidb) {
  378. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  379. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  380. if constexpr (Is_causal || Is_local) {
  381. int const seqlen_q = get_seqlen_q(params, bidb);
  382. int const seqlen_k = get_seqlen_k(params, bidb);
  383. return std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k - params.window_size_right) / kBlockM);
  384. } else {
  385. return 0;
  386. }
  387. }
  388. CUTLASS_DEVICE
  389. int get_m_block_max(Params const& params, int n_block, int bidb) {
  390. static constexpr int kBlockM = get<0>(TileShape_MNK{});
  391. static constexpr int kBlockN = get<1>(TileShape_MNK{});
  392. int const seqlen_q = get_seqlen_q(params, bidb);
  393. int const seqlen_k = get_seqlen_k(params, bidb);
  394. int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
  395. if constexpr (Is_local) {
  396. return std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + params.window_size_left, kBlockM));
  397. } else {
  398. return m_block_max;
  399. }
  400. }
// Producer path for one (n_block, head, batch) work tile: loads the K/V tile
// once up front, then streams Q and dO (with their LSE / dPsum rows) through
// the two kStages-deep TMA pipelines over m-blocks. Software-pipelined so the
// Q for iteration m+1 is issued alongside the dO for iteration m. Only the
// single elected lane issues TMA copies. `work_idx` is unused in this body.
template <typename SchedulerPrefetch, typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& params,
     MainloopPipeline pipeline_q,
     MainloopPipeline pipeline_do,
     PipelineState& smem_pipe_write,
     SharedStorage &shared_storage,
     SchedulerPrefetch const& scheduler_prefetch,
     cute::tuple<int32_t, int32_t, int32_t> block_coord,
     int work_idx
     ) {
    // Shared-memory views of the destination buffers.
    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQ{});
    Tensor sdO = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdO{});
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutK{});
    Tensor sV = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_v.data()), SmemLayoutV{});
    Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_lse.data()), SmemLayoutLSE{});
    Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dpsum.data()), SmemLayoutLSE{});
    auto [n_block, bidh, bidb] = block_coord;
    // Map the Q head to its KV head (qhead_per_khead computed from the head
    // counts in to_underlying_arguments) for GQA/MQA.
    int bidh_kv = params.qhead_per_khead_divmod.divide(bidh);
    // Prepare the TMA loads
    uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
    // With varlen, the batch coordinate collapses to 0 and per-batch offsets
    // are applied below via domain_offset.
    bool const is_varlen_q = Varlen && params.cu_seqlens_q != nullptr;
    bool const is_varlen_k = Varlen && params.cu_seqlens_k != nullptr;
    Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
    Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
    Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
    Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
    Tensor mLSE = params.tma_load_LSE.get_tma_tensor(params.shape_LSE)(_, bidh, !is_varlen_q ? bidb : 0);
    Tensor mdPsum = params.tma_load_dPsum.get_tma_tensor(params.shape_LSE)(_, bidh, !is_varlen_q ? bidb : 0);
    int const offset_q = !is_varlen_q ? 0 : params.cu_seqlens_q[bidb];
    int const offset_k = !is_varlen_k ? 0 : params.cu_seqlens_k[bidb];
    // LSE/dPsum offset rounded so each batch's slice starts on a 128-row
    // boundary (bidb * 128 of extra slack per batch) -- presumably matching
    // the host-side padded allocation; confirm against the host code.
    int const offset_padded = !is_varlen_q ? 0 : (params.cu_seqlens_q[bidb] + bidb * 128) / 128 * 128;
    Tensor gQ = local_tile(domain_offset(make_coord(offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
    Tensor gdO = local_tile(domain_offset(make_coord(offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
    Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
    Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
    Tensor gLSE = local_tile(domain_offset(make_coord(offset_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)
    Tensor gdPsum = local_tile(domain_offset(make_coord(offset_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)
    // Append a trivial mode so K/V partition like the pipelined tensors.
    Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{}));
    Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{}));
    Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{}));
    Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{}));
    auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout<ClusterShape>{},
                                      group_modes<0, 2>(sQ), group_modes<0, 2>(gQ));  // (TMA, k), (TMA, PIPE)
    auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sdO), group_modes<0, 2>(gdO));  // (TMA, k), (TMA, PIPE)
    auto [tKgK, tKsK] = tma_partition(params.tma_load_K, _0{}, Layout<_1>{},
                                      group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x));  // (TMA), (TMA)
    auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{},
                                      group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x));  // (TMA), (TMA)
    auto [tLSEgLSE, tLSEsLSE] = tma_partition(params.tma_load_LSE, _0{}, Layout<_1>{},
                                              sLSE, gLSE);  // (TMA, k), (TMA, PIPE)
    auto [tLSEgdPsum, tLSEsdPsum] = tma_partition(params.tma_load_dPsum, _0{}, Layout<_1>{},
                                                  sdPsum, gdPsum);  // (TMA, k), (TMA, PIPE)
    // Build the multicast mask over cluster blocks sharing this Q/dO load.
    uint16_t mcast_mask_qdo = 0;
    if constexpr (cute::is_same_v<GmemTiledCopyQdO, SM90_TMA_LOAD_MULTICAST>) {
        auto block_layout = Layout<ClusterShape>{};  // (m,n) -> block_id
        for (int n = 0; n < size<1>(block_layout); ++n) {
            mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{}));
        }
    }
    int m_block_max = get_m_block_max(params, n_block, bidb);
    int m_block_min = get_m_block_min(params, n_block, bidb);
    int m_block = m_block_min;
    int lane_predicate = cute::elect_one_sync();
    // // Wait for the MMA warpgroups to say that smem_q is ready
    // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::QueryEmpty) /*id*/);
    if (lane_predicate) {
        // Copy K tile and V tile from GMEM to SMEM.
        // K and V share one transaction barrier, so expect both byte counts.
        shared_storage.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV);
        copy(params.tma_load_K.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK);
        copy(params.tma_load_V.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV);
        // Prologue: fill the first stage with Q + LSE for m_block_min.
        pipeline_q.producer_acquire(smem_pipe_write);
        copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index()));
        copy(params.tma_load_LSE.with(*pipeline_q.producer_get_barrier(smem_pipe_write), 0), tLSEgLSE(_, m_block), tLSEsLSE(_, smem_pipe_write.index()));
        // Steady state: dO/dPsum for m_block, then Q/LSE for m_block + 1,
        // advancing the shared write state between them.
        #pragma unroll 2
        for (; m_block < m_block_max - 1; ++m_block) {
            pipeline_do.producer_acquire(smem_pipe_write);
            copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write.index()));
            copy(params.tma_load_dPsum.with(*pipeline_do.producer_get_barrier(smem_pipe_write), 0), tLSEgdPsum(_, m_block), tLSEsdPsum(_, smem_pipe_write.index()));
            ++smem_pipe_write;
            pipeline_q.producer_acquire(smem_pipe_write);
            copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index()));
            copy(params.tma_load_LSE.with(*pipeline_q.producer_get_barrier(smem_pipe_write), 0), tLSEgLSE(_, m_block + 1), tLSEsLSE(_, smem_pipe_write.index()));
        }
    }
    // Let the scheduler prefetch the next work tile before the final dO load.
    scheduler_prefetch();
    if (lane_predicate) {
        // Epilogue: dO/dPsum for the last m-block.
        pipeline_do.producer_acquire(smem_pipe_write);
        copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write), mcast_mask_qdo), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write.index()));
        copy(params.tma_load_dPsum.with(*pipeline_do.producer_get_barrier(smem_pipe_write), 0), tLSEgdPsum(_, m_block), tLSEsdPsum(_, smem_pipe_write.index()));
        ++smem_pipe_write;
    }
}
  497. /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  498. CUTLASS_DEVICE void
  499. load_tail(MainloopPipeline pipeline_q, MainloopPipeline pipeline_do,
  500. PipelineState& smem_pipe_write) {
  501. // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write
  502. PipelineState smem_pipe_write_do = smem_pipe_write;
  503. int lane_predicate = cute::elect_one_sync();
  504. // Issue the epilogue waits
  505. if (lane_predicate) {
  506. /* This helps avoid early exit of blocks in Cluster
  507. * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
  508. * then would just be acquired since the phase was still inverted from make_producer_start_state
  509. */
  510. pipeline_q.producer_tail(smem_pipe_write);
  511. pipeline_do.producer_tail(smem_pipe_write_do);
  512. }
  513. }
// Drains the dQ accumulator: for each m_block, waits until the MMA warpgroups
// have filled sdQ, issues a TMA reduce-add of sdQ into global dQaccum, then
// signals sdQ empty again. With Deterministic, a global semaphore serializes
// the adds across n_blocks so the accumulation order is fixed.
template <typename SharedStorage>
CUTLASS_DEVICE void
store_dq(Params const& params,
         SharedStorage &shared_storage,
         cute::tuple<int32_t, int32_t, int32_t> block_coord   // (n_block, bidh, bidb)
         ) {
    // Two views of the same dQ accumulator smem buffer (swizzled for TMA, and unswizzled).
    Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMA{});
    Tensor sdQnoswizzle = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccumTMANoSwizzle{});
    auto [n_block, bidh, bidb] = block_coord;
    bool const is_varlen_q = Varlen && params.cu_seqlens_q != nullptr;
    // We reshaped dQaccum to have last dimension 32, so the offset needs to be multiplied by kHeadDim / 32.
    // The (cu_seqlens + bidb*128)/128*128 rounding presumably matches how varlen dQaccum
    // rows were padded per batch — TODO(review): confirm against the allocator.
    int const offset_padded = !is_varlen_q ? 0 : ((params.cu_seqlens_q[bidb] + bidb * 128) / 128 * 128) * (kHeadDim / ElemsPerRowTMA);
    // Prepare the TMA loads: global dQaccum tensor for this head/batch, tiled per m_block.
    Tensor mdQaccum = params.tma_add_dQ.get_tma_tensor(params.shape_dQaccum)(_, _, bidh, !is_varlen_q ? bidb : 0);
    Tensor gdQaccum = local_tile(domain_offset(make_coord(offset_padded, _0{}), mdQaccum), TileShape_dQaccum{}, make_coord(_, _0{}));  // (M, K, _)
    auto block_tma_dQ = params.tma_add_dQ.get_slice(_0{});
    Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum);  // (TMA, TMA_M, TMA_K)
    Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ);       // (TMA, TMA_M, TMA_K)
    int m_block_max = get_m_block_max(params, n_block, bidb);
    int m_block_min = get_m_block_min(params, n_block, bidb);
    int m_block = m_block_min;
    int const num_batch = params.num_batch;
    int const num_head = get<2>(params.shape_Q);
    // One semaphore slot per (batch, head); only used in Deterministic mode.
    int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh;
    using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
    int lane_predicate = cute::elect_one_sync();
    #pragma unroll 2
    for (; m_block < m_block_max; ++m_block) {
        if constexpr (Deterministic) {
            // Wait until it's this n_block's turn to add into dQaccum for this m_block.
            Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block);
        }
        cutlass::arch::NamedBarrier::sync(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQFull) /*id*/);  // sdQ full, to be written to gmem
        if (lane_predicate) {
            // TMA reduce-add: sdQ += into gdQaccum for this m_block.
            cute::copy(params.tma_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block));
            tma_store_arrive();
        }
        // All lanes wait for the TMA store to complete before releasing sdQ.
        tma_store_wait<0>();
        if constexpr (Deterministic) {
            // Pass the turn to the next n_block.
            Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
        }
        cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/);  // sdQ empty, ready to be written to
    }
    if constexpr (Is_local && Deterministic) {
        // Local attention may stop before the global m_block range ends; still
        // bump the semaphore for the remaining m_blocks so other n_blocks
        // waiting on their turn are not deadlocked.
        constexpr int kBlockM = get<0>(TileShape_MNK{});
        int const seqlen_q = get_seqlen_q(params, bidb);
        int const m_block_global_max = cute::ceil_div(seqlen_q, kBlockM);
        #pragma unroll 2
        for (; m_block < m_block_global_max; ++m_block) {
            Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
        }
    }
}
  566. CUTLASS_DEVICE void
  567. mma_init() {
  568. // // Tell producer (warp 0) that smem_q is ready
  569. // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::QueryEmpty) /*id*/);
  570. int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
  571. if (cutlass::canonical_warp_group_idx() == 1 && warp_idx_in_warpgroup == 0) {
  572. cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
  573. }
  574. }
  575. template <typename SharedStorage, typename FrgTensordKV>
  576. CUTLASS_DEVICE void
  577. mma(Params const& params,
  578. MainloopPipeline pipeline_q,
  579. MainloopPipeline pipeline_do,
  580. PipelineState& smem_pipe_read,
  581. FrgTensordKV& tdKrdK,
  582. FrgTensordKV& tdVrdV,
  583. int thread_idx,
  584. int work_idx,
  585. cute::tuple<int32_t, int32_t, int32_t> block_coord,
  586. SharedStorage& shared_storage
  587. ) {
  588. static_assert(is_rmem<FrgTensordKV>::value, "dK and dV tensor must be rmem resident.");
  589. Tensor sQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQ{});
  590. Tensor sdO = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdO{});
  591. Tensor sK = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutK{});
  592. Tensor sV = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_v.data()), SmemLayoutV{});
  593. Tensor sQt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_q.data()), SmemLayoutQt{});
  594. Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_do.data()), SmemLayoutdOt{});
  595. Tensor sKt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_k.data()), SmemLayoutKt{});
  596. Tensor sdS = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_ds.data()), SmemLayoutdS{});
  597. Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_ds.data()), SmemLayoutdSt{});
  598. Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});
  599. Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_lse.data()), SmemLayoutLSEMma{});
  600. Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{});
  601. static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and
  602. stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and
  603. size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and
  604. size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup,
  605. "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
  606. constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
  607. Layout warp_group_thread_layout = make_layout(make_shape(Int<MmaWarpGroups>{}),
  608. make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
  609. Layout warp_group_thread_layout_dq = make_layout(make_shape(Int<NumdQWarpGroups>{}),
  610. make_stride(Int<cutlass::NumThreadsPerWarpGroup>{}));
  611. int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
  612. TiledMmaSdP tiled_mma_SdP;
  613. TiledMmadKV tiled_mma_dKV;
  614. TiledMmadQ tiled_mma_dQ;
  615. static_assert(!dKV_swapAB);
  616. auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx));
  617. auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);
  618. auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx));
  619. auto wg_mma_dQ = tiled_mma_dQ.get_slice(!Varlen ? warp_group_thread_layout_dq(NumdQWarpGroups == 2 ? warp_group_idx : 0) : thread_idx);
  620. // auto wg_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx);
  621. auto smem_tiled_copy_PdS = make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP);
  622. auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx);
  623. Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N)
  624. R2STiledCopydQaccum r2s_tiled_copy_dQaccum;
  625. // auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
  626. auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(NumdQWarpGroups == 2 ? thread_idx : thread_idx % cutlass::NumThreadsPerWarpGroup);
  627. Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ);
  628. // Allocate "fragments/descriptors"
  629. Tensor tSrQ = wg_mma_SdP.partition_fragment_B(sQ);
  630. Tensor tSrK = wg_mma_SdP.partition_fragment_A(sK);
  631. Tensor tdPrdO = wg_mma_SdP.partition_fragment_B(sdO);
  632. Tensor tdPrV = wg_mma_SdP.partition_fragment_A(sV);
  633. Tensor tdVrdO = wg_mma_dKV.partition_fragment_B(sdOt);
  634. Tensor tdKrQ = wg_mma_dKV.partition_fragment_B(sQt);
  635. int n_block = get<0>(block_coord);
  636. int bidh = get<1>(block_coord);
  637. int bidb = get<2>(block_coord);
  638. int const seqlen_q = get_seqlen_q(params, bidb);
  639. int const seqlen_k = get_seqlen_k(params, bidb);
  640. int m_block_max = get_m_block_max(params, n_block, bidb);
  641. int m_block_min = get_m_block_min(params, n_block, bidb);
  642. int m_block = m_block_min;
  643. // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the row indices.
  644. Tensor tLSEsLSE = thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _0{}, _); // (2, V, PIPE)
  645. Tensor tLSEsdPsum = thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _0{}, _);
  646. clear(tdKrdK);
  647. clear(tdVrdV);
  648. // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero;
  649. cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_KV.try_wait(work_idx % 2));
  650. if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_KV.wait(work_idx % 2); }
  651. auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
  652. auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  653. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  654. };
  655. auto compute_dQ = [&]() {
  656. static_assert(!Mma_dQ_is_RS);
  657. // SMEM fence to make sure sP is written before it's read by WGMMA
  658. cutlass::arch::fence_view_async_shared();
  659. cutlass::arch::NamedBarrier::sync(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQEmpty) /*id*/); // sdQ empty, ready to be written to
  660. Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 2, !dQ_swapAB ? 2 : 0>(TileShape_MNK{}));
  661. if constexpr (!dQ_swapAB) {
  662. Tensor tdQrdS = wg_mma_dQ.partition_fragment_A(sdS);
  663. Tensor tdQrK = wg_mma_dQ.partition_fragment_B(sKt);
  664. flash::gemm</*zero_init=*/true, /*wg_wait=*/1>(tiled_mma_dQ, tdQrdS(_, _, _, smem_pipe_read.index()), tdQrK, tdQrdQ);
  665. } else {
  666. Tensor tdQrdS = wg_mma_dQ.partition_fragment_B(sdS);
  667. Tensor tdQrK = wg_mma_dQ.partition_fragment_A(sKt);
  668. flash::gemm</*zero_init=*/true, /*wg_wait=*/1>(tiled_mma_dQ, tdQrK, tdQrdS(_, _, _, smem_pipe_read.index()), tdQrdQ);
  669. }
  670. pipeline_q.consumer_release(smem_pipe_read); // release Q
  671. warpgroup_wait<0>();
  672. Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
  673. cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum);
  674. cutlass::arch::fence_view_async_shared();
  675. cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + cutlass::NumThreadsPerWarp, static_cast<int>(BwdNamedBarriers::dQFull) /*id*/); // sdQ full, to be written to gmem
  676. };
  677. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
  678. // this helps quite a bit to not have to do causal masking for most of the iterations.
  679. if constexpr (Is_causal) {
  680. static constexpr int n_masking_steps = cute::ceil_div(kBlockN, kBlockM) + 1;
  681. CUTLASS_PRAGMA_NO_UNROLL
  682. for (; m_block < std::min(m_block_max, m_block_min + n_masking_steps); ++m_block) {
  683. Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
  684. pipeline_q.consumer_wait(smem_pipe_read);
  685. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tSrK, tSrQ(_, _, _, smem_pipe_read.index()), tSrS);
  686. Tensor tLSErLSE = make_fragment_like(tLSEsLSE(_, _, _0{}));
  687. cute::copy(tLSEsLSE(_, _, smem_pipe_read.index()), tLSErLSE);
  688. Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
  689. pipeline_do.consumer_wait(smem_pipe_read);
  690. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read.index()), tdPrdP);
  691. warpgroup_wait<1>();
  692. Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{}));
  693. Tensor taccScS = thread_mma_SdP.partition_C(cS);
  694. int causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
  695. #pragma unroll
  696. for (int i = 0; i < size(tSrS); ++i) {
  697. if (int(get<0>(taccScS(i))) >=
  698. std::min(int(get<1>(taccScS(i))) + causal_row_offset, seqlen_k - n_block * kBlockN)) {
  699. tSrS(i) = -INFINITY;
  700. }
  701. }
  702. // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  703. Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
  704. flash::scale_apply_exp2</*Scale_max=*/false, /*Check_inf=*/false>(scores, group_modes<0, 2>(tLSErLSE), params.softmax_scale_log2);
  705. Tensor tLSErdPsum = make_fragment_like(tLSEsdPsum(_, _, _0{}));
  706. cute::copy(tLSEsdPsum(_, _, smem_pipe_read.index()), tLSErdPsum);
  707. // Convert scores from fp32 to fp16/bf16
  708. Tensor rP = flash::convert_type<Element>(tSrS);
  709. warpgroup_wait<0>();
  710. // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  711. Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
  712. for (int mi = 0; mi < size<0>(dS); ++mi) {
  713. #pragma unroll
  714. for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - tLSErdPsum(mi)); }
  715. }
  716. Tensor rdS = flash::convert_type<Element>(tdPrdP);
  717. // Because of double buffering on dS, we don't need to sync here.
  718. // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
  719. // But because both WGs have to sync at the end of the loop and double buffering, this race condition
  720. // is not possible.
  721. Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
  722. cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, smem_pipe_read.index()));
  723. Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
  724. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read.index()), tdVrdV);
  725. Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
  726. flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  727. pipeline_do.consumer_release(smem_pipe_read); // release dO
  728. compute_dQ();
  729. ++smem_pipe_read;
  730. }
  731. }
  732. CUTLASS_PRAGMA_NO_UNROLL
  733. for (; m_block < m_block_max; ++m_block) {
  734. Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
  735. pipeline_q.consumer_wait(smem_pipe_read);
  736. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tSrK, tSrQ(_, _, _, smem_pipe_read.index()), tSrS);
  737. Tensor tLSErLSE = make_fragment_like(tLSEsLSE(_, _, _0{}));
  738. cute::copy(tLSEsLSE(_, _, smem_pipe_read.index()), tLSErLSE);
  739. Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select<1, 0>(TileShape_MNK{}));
  740. pipeline_do.consumer_wait(smem_pipe_read);
  741. flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_SdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read.index()), tdPrdP);
  742. warpgroup_wait<1>();
  743. Tensor cS = cute::make_identity_tensor(select<1, 0>(TileShape_MNK{}));
  744. Tensor taccScS = thread_mma_SdP.partition_C(cS);
  745. if constexpr (!Is_local) {
  746. #pragma unroll
  747. for (int i = 0; i < size(tSrS); ++i) {
  748. if (int(get<0>(taccScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
  749. }
  750. } else {
  751. int local_row_offset_right = 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM + params.window_size_right;
  752. int local_row_offset_left = seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM - params.window_size_left;
  753. #pragma unroll
  754. for (int i = 0; i < size(tSrS); ++i) {
  755. if ((int(get<0>(taccScS(i))) >= std::min(int(get<1>(taccScS(i))) + local_row_offset_right, seqlen_k - n_block * kBlockN)) ||
  756. (int(get<0>(taccScS(i))) < std::max(int(get<1>(taccScS(i))) + local_row_offset_left, 0))) {
  757. tSrS(i) = -INFINITY;
  758. }
  759. }
  760. }
  761. // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  762. Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout()));
  763. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tLSErLSE); }
  764. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(scores); }
  765. flash::scale_apply_exp2</*Scale_max=*/false, /*Check_inf=*/false>(scores, group_modes<0, 2>(tLSErLSE), params.softmax_scale_log2);
  766. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(scores); }
  767. Tensor tLSErdPsum = make_fragment_like(tLSEsdPsum(_, _, _0{}));
  768. cute::copy(tLSEsdPsum(_, _, smem_pipe_read.index()), tLSErdPsum);
  769. // Convert scores from fp32 to fp16/bf16
  770. Tensor rP = flash::convert_type<Element>(tSrS);
  771. warpgroup_wait<0>();
  772. // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
  773. Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
  774. #pragma unroll
  775. for (int mi = 0; mi < size<0>(dS); ++mi) {
  776. #pragma unroll
  777. for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - tLSErdPsum(mi)); }
  778. }
  779. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dS); }
  780. Tensor rdS = flash::convert_type<Element>(tdPrdP);
  781. Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N)
  782. cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, smem_pipe_read.index()));
  783. Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs<TiledMmadKV>(tSrS.layout()));
  784. flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read.index()), tdVrdV);
  785. Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs<TiledMmadKV>(tdPrdP.layout()));
  786. flash::gemm</*zero_init=*/false, /*wg_wait=*/1>(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK);
  787. pipeline_do.consumer_release(smem_pipe_read); // release dO
  788. compute_dQ();
  789. ++smem_pipe_read;
  790. }
  791. // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }
  792. #pragma unroll
  793. for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }
  794. }
  795. };
  796. } // namespace flash