// kernel_traits.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
#pragma once

#include <cstdint>
#include <type_traits>

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

using namespace cute;
  13. template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
  14. class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
  15. struct SharedStorageQKVO {
  16. cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
  17. cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
  18. union {
  19. cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
  20. cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
  21. };
  22. struct {
  23. cute::uint64_t tma_load_mbar[4]; // 4 TMA barriers pre-allocated for usage.
  24. cutlass::arch::ClusterTransactionBarrier barrier_Q;
  25. cutlass::arch::ClusterBarrier barrier_O;
  26. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
  27. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
  28. int tile_count_semaphore;
  29. };
  30. };
  31. // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
  32. template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
  33. int kClusterM_ = 1, typename elem_type=cutlass::half_t>
  34. struct Flash_fwd_kernel_traits {
  35. using Element = elem_type;
  36. using ElementAccum = float;
  37. using index_t = int64_t;
  38. using ElementO = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(cutlass::half_t{}, Element{}));
  39. // The number of threads.
  40. static constexpr int kNWarps = kNWarps_;
  41. static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  42. static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
  43. static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
  44. static constexpr bool Is_WS = kNWarps_ >= 12;
  45. static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");
  46. static constexpr int kBlockM = kBlockM_;
  47. static constexpr int kBlockN = kBlockN_;
  48. static constexpr int kHeadDim = kHeadDim_;
  49. static_assert(kHeadDim % 32 == 0);
  50. using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  51. static constexpr int kClusterM = kClusterM_;
  52. using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
  53. static constexpr int kStages = kStages_;
  54. using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
  55. using TiledMma0 = decltype(cute::make_tiled_mma(
  56. std::conditional_t<
  57. Is_Q_in_regs,
  58. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
  59. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
  60. >{},
  61. AtomLayoutMNK{}));
  62. using TiledMma1 = decltype(cute::make_tiled_mma(
  63. cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
  64. GMMA::Major::K, cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(
  65. GMMA::Major::K, GMMA::Major::MN)>(),
  66. AtomLayoutMNK{}));
  67. using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  68. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  69. using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
  70. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  71. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  72. using SmemLayoutK =
  73. decltype(tile_to_shape(SmemLayoutAtomK{},
  74. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  75. using SmemLayoutAtomVFp16 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  76. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  77. using SmemLayoutVFp16 =
  78. decltype(tile_to_shape(SmemLayoutAtomVFp16{},
  79. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  80. using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  81. decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  82. using SmemLayoutVFp8 =
  83. decltype(tile_to_shape(SmemLayoutAtomVFp8{},
  84. make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
  85. using SmemLayoutV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVFp16{}));
  86. // Note this is the transpose in terms of the view, not in terms of memory.
  87. using SmemLayoutVtFp16 =
  88. decltype(cute::composition(SmemLayoutVFp16{},
  89. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
  90. make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutVFp16{}(_, _, _0{}))>{}))));
  91. using SmemLayoutVt = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVtFp16{}));
  92. using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementO,
  93. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  94. using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
  95. using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, ElementO>;
  96. using SharedStorage = SharedStorageQKVO<kStages, Element, Element, ElementO, SmemLayoutQ,
  97. SmemLayoutK, SmemLayoutV, SmemLayoutO>;
  98. using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  99. using PipelineState = typename cutlass::PipelineState<kStages>;
  100. // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
  101. };
  102. ////////////////////////////////////////////////////////////////////////////////////////////////////
  103. template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  104. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
  105. class SmemLayoutdK, class SmemLayoutdV>
  106. struct SharedStorageQKVdOdKV;
  107. template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  108. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
  109. class SmemLayoutdK, class SmemLayoutdV>
  110. struct SharedStorageQKVdOdKV<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
  111. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
  112. struct {
  113. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  114. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
  115. union {
  116. struct {
  117. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  118. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  119. };
  120. struct {
  121. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
  122. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
  123. };
  124. };
  125. cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
  126. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
  127. };
  128. struct {
  129. cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
  130. cutlass::arch::ClusterTransactionBarrier barrier_K;
  131. cutlass::arch::ClusterTransactionBarrier barrier_V;
  132. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
  133. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
  134. };
  135. };
  136. template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  137. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
  138. class SmemLayoutdK, class SmemLayoutdV>
  139. struct SharedStorageQKVdOdKV<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
  140. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
  141. struct {
  142. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  143. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
  144. union {
  145. struct {
  146. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  147. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  148. };
  149. struct {
  150. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
  151. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
  152. };
  153. };
  154. union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
  155. cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
  156. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
  157. };
  158. };
  159. struct {
  160. cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
  161. cutlass::arch::ClusterTransactionBarrier barrier_K;
  162. cutlass::arch::ClusterTransactionBarrier barrier_V;
  163. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
  164. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
  165. };
  166. };
  167. template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  168. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
  169. class SmemLayoutdK, class SmemLayoutdV>
  170. struct SharedStorageQKVdOdKVWS;
  171. template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  172. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
  173. class SmemLayoutdK, class SmemLayoutdV>
  174. struct SharedStorageQKVdOdKVWS<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
  175. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
  176. struct {
  177. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  178. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
  179. union {
  180. struct {
  181. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  182. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  183. };
  184. struct {
  185. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
  186. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
  187. };
  188. };
  189. cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
  190. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
  191. cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
  192. cute::array_aligned<float, 128> smem_lse;
  193. cute::array_aligned<float, 128> smem_dpsum;
  194. };
  195. struct {
  196. cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
  197. cutlass::arch::ClusterTransactionBarrier barrier_K;
  198. cutlass::arch::ClusterTransactionBarrier barrier_V;
  199. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
  200. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
  201. };
  202. };
  203. template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  204. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
  205. class SmemLayoutdK, class SmemLayoutdV>
  206. struct SharedStorageQKVdOdKVWS<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
  207. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
  208. struct {
  209. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  210. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
  211. union {
  212. struct {
  213. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  214. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  215. };
  216. struct {
  217. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
  218. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
  219. };
  220. };
  221. union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
  222. cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
  223. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
  224. };
  225. cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
  226. cute::array_aligned<float, 128> smem_lse;
  227. cute::array_aligned<float, 128> smem_dpsum;
  228. };
  229. struct {
  230. cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
  231. cutlass::arch::ClusterTransactionBarrier barrier_K;
  232. cutlass::arch::ClusterTransactionBarrier barrier_V;
  233. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
  234. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
  235. };
  236. };
  237. template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  238. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
  239. class SmemLayoutdQ>
  240. struct SharedStorageQKVdOdKVSeqqPar;
  241. template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  242. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
  243. class SmemLayoutdQ>
  244. struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
  245. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
  246. struct {
  247. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  248. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  249. union {
  250. struct {
  251. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  252. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
  253. };
  254. struct {
  255. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
  256. };
  257. };
  258. cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
  259. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
  260. };
  261. struct {
  262. cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
  263. cutlass::arch::ClusterTransactionBarrier barrier_Q;
  264. cutlass::arch::ClusterTransactionBarrier barrier_dO;
  265. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
  266. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
  267. };
  268. };
  269. template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
  270. class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
  271. class SmemLayoutdQ>
  272. struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
  273. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
  274. struct {
  275. cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
  276. cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
  277. union {
  278. struct {
  279. cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  280. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
  281. };
  282. struct {
  283. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
  284. };
  285. };
  286. union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
  287. cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
  288. cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
  289. };
  290. };
  291. struct {
  292. cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage.
  293. cutlass::arch::ClusterTransactionBarrier barrier_Q;
  294. cutlass::arch::ClusterTransactionBarrier barrier_dO;
  295. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
  296. typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
  297. };
  298. };
  299. ////////////////////////////////////////////////////////////////////////////////////////////////////
  300. template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
  301. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  302. int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  303. int kClusterN_ = 1, typename elem_type=cutlass::half_t>
  304. struct Flash_bwd_kernel_traits {
  305. using Element = elem_type;
  306. using ElementAccum = float;
  307. using index_t = int64_t;
  308. // The number of threads.
  309. static constexpr int kNWarps = kNWarps_;
  310. static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  311. static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
  312. // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
  313. static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;
  314. static_assert(kNWarps_ == 8 || kNWarps_ == 12);
  315. static constexpr bool Is_WS = kNWarps_ >= 12;
  316. static constexpr int kBlockM = kBlockM_;
  317. static constexpr int kBlockN = kBlockN_;
  318. static constexpr int kHeadDim = kHeadDim_;
  319. static_assert(kHeadDim % 32 == 0);
  320. using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  321. static constexpr int kClusterN = kClusterN_;
  322. using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
  323. static constexpr int kStages = 2;
  324. static constexpr bool SdP_swapAB = SdP_swapAB_;
  325. static constexpr bool dKV_swapAB = dKV_swapAB_;
  326. static constexpr bool dQ_swapAB = dQ_swapAB_;
  327. static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
  328. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  329. using TileShapeAtomSdP = std::conditional_t<
  330. !SdP_swapAB,
  331. Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
  332. Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
  333. >;
  334. using AtomLayoutSdP = std::conditional_t<
  335. !SdP_swapAB,
  336. Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
  337. Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  338. >;
  339. using TiledMmaSdP = decltype(cute::make_tiled_mma(
  340. cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  341. AtomLayoutSdP{}));
  342. using TileShapeAtomdKV = std::conditional_t<
  343. !dKV_swapAB,
  344. Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
  345. Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
  346. >;
  347. using AtomLayoutdKV = std::conditional_t<
  348. !dKV_swapAB,
  349. Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
  350. Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  351. >;
  352. using TiledMmadKV = decltype(cute::make_tiled_mma(
  353. std::conditional_t<
  354. !SdP_swapAB,
  355. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
  356. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
  357. >{},
  358. AtomLayoutdKV{}));
  359. using TileShapeAtomdQ = std::conditional_t<
  360. !dQ_swapAB,
  361. Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
  362. Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
  363. // Shape<Int<kBlockM>, Int<kHeadDim >, Int<kBlockN>>,
  364. // Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
  365. >;
  366. using AtomLayoutdQ = std::conditional_t<
  367. !dQ_swapAB,
  368. Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
  369. Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  370. // Layout<Shape<Int<1>, Int<1>, _1>>,
  371. // Layout<Shape<Int<1>, Int<1>, _1>>
  372. >;
  373. static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
  374. static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
  375. using TiledMmadQ = decltype(cute::make_tiled_mma(
  376. std::conditional_t<
  377. !dQ_swapAB,
  378. std::conditional_t<
  379. Mma_dQ_is_RS,
  380. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
  381. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
  382. >,
  383. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
  384. >{},
  385. AtomLayoutdQ{}));
  386. using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  387. using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  388. using GmemTiledCopydKV = cute::SM90_TMA_STORE;
  389. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  390. static constexpr bool Has_cp_async = true;
  391. #else
  392. static constexpr bool Has_cp_async = false;
  393. #endif
  394. // For the dot_do_o preprocessing kernel
  395. using Gmem_copy_struct = std::conditional_t<
  396. Has_cp_async,
  397. SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
  398. DefaultCopy
  399. >;
  400. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
  401. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  402. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
  403. // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
  404. // to affect speed in practice.
  405. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
  406. static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
  407. using GmemLayoutAtom = Layout<Shape <Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  408. Stride<Int<kGmemThreadsPerRow>, _1>>;
  409. using GmemLayoutAtomdQ = Layout<Shape <Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  410. Stride<Int<kGmemThreadsPerRow>, _1>>;
  411. using GmemTiledCopydO = decltype(
  412. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  413. GmemLayoutAtom{},
  414. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  415. using GmemTiledCopydQ = decltype(
  416. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  417. GmemLayoutAtomdQ{},
  418. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  419. using GmemLayoutAtomdQaccum = std::conditional_t<
  420. kBlockKSmem == 32,
  421. Layout<Shape <Int<kNThreadsdQ / 8>, _8>, // Thread layout, 8 threads per row
  422. Stride< _8, _1>>,
  423. Layout<Shape <Int<kNThreadsdQ / 16>, _16>, // Thread layout, 16 threads per row
  424. Stride< _16, _1>>
  425. >;
  426. using GmemTiledCopydQaccum = decltype(
  427. make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
  428. GmemLayoutAtomdQaccum{},
  429. Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
  430. using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  431. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  432. using SmemLayoutQ =
  433. decltype(tile_to_shape(SmemLayoutAtomQ{},
  434. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  435. using SmemLayoutdO = SmemLayoutQ;
  436. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  437. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  438. using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
  439. using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  440. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  441. using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
  442. using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  443. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  444. using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
  445. using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  446. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  447. using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
  448. // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
  449. // decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  450. // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
  451. // Note this is the transpose in terms of the view, not in terms of memory.
  452. using SmemLayoutQt =
  453. decltype(cute::composition(SmemLayoutQ{},
  454. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  455. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  456. using SmemLayoutdOt =
  457. decltype(cute::composition(SmemLayoutdO{},
  458. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  459. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  460. using SmemLayoutKt =
  461. decltype(cute::composition(SmemLayoutK{},
  462. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  463. make_stride(Int<kBlockN>{}, _1{}))));
  464. using SmemLayoutPt =
  465. decltype(cute::composition(SmemLayoutP{},
  466. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  467. make_stride(Int<kBlockM>{}, _1{}))));
  468. using SmemLayoutdSt =
  469. decltype(cute::composition(SmemLayoutdS{},
  470. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  471. make_stride(Int<kBlockM>{}, _1{}))));
    // using SmemLayoutdQacct =
    // decltype(cute::composition(SmemLayoutdQacc{},
    // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    // make_stride(Int<kBlockM>{}, _1{}))));
    // dK/dV outputs reuse the K/V smem layouts (same (kBlockN, kHeadDim) tiles).
    using SmemLayoutdK = SmemLayoutK;
    using SmemLayoutdV = SmemLayoutV;
    using SmemLayoutdKt = SmemLayoutKt;
    // NOTE(review): dVt intentionally aliases SmemLayoutKt — assumes the K and V
    // smem layouts coincide (same tile shape), so their transposed views do too.
    using SmemLayoutdVt = SmemLayoutKt;
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    // Swizzled smem atom for dQ; the thread-count-derived shape spreads the
    // kNThreadsdQ threads over rows of 32 elements.
    using SmemLayoutAtomdQ = decltype(
        // composition(Swizzle<kSwizzle, 3, 3>{},
        composition(Swizzle<3, 3, 3>{},
                    Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
                           Stride<Int<32>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    // Transposed (kHeadDim, kBlockM) view of the same dQ smem buffer
    // (a re-composition of SmemLayoutdQ, not a separate allocation).
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    // Size in bytes of the dQ smem tile.
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    // TMA-compatible layout for the fp32 dQ accumulator tile.
    using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdQacc = SmemLayoutdQ;
    using SmemLayoutdQacct = SmemLayoutdQt;
    // Double-width (2-buffer) variant of the dQ accumulator layout.
    using SmemLayoutdQacc2 = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
    // using SmemLayoutdQacct =
    // decltype(cute::composition(SmemLayoutdQacc{},
    // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    // make_stride(Int<kBlockM>{}, _1{}))));
    using RmemTiledCopydQacc = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    // STSM copy atoms: the transposed (U16x8 _T) variant is used whenever the
    // corresponding GEMM has its A/B operands swapped.
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    // Aggregate smem struct: the warp-specialized (Is_WS) variant additionally
    // carries the SmemLayoutdQacc buffer.
    using SharedStorage = std::conditional_t<
        !Is_WS,
        SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
            SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
        SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
            SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
            // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
    >;
    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
    // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
    // TMA-based producer/consumer pipeline with one slot per smem stage.
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
  533. ////////////////////////////////////////////////////////////////////////////////////////////////////
// Kernel traits for the FlashAttention backward-pass variant that parallelizes
// over the sequence dimension of Q ("seqqpar"): Q/dO tiles are pipelined through
// kStages smem buffers while dK/dV are the per-CTA outputs.
//
// Template parameters:
//   kHeadDim_/kBlockM_/kBlockN_ : head dimension and M/N tile sizes.
//   kNWarps_                    : warps per CTA (must be 8, see static_assert below).
//   SdP_swapAB_/dKV_swapAB_/dQ_swapAB_ : whether to swap the A/B operands of the
//       S/dP, dK/dV, and dQ GEMMs respectively (selects operand majors and the
//       normal vs. transposed STSM copy atoms below).
//   AtomLayoutMSdP/AtomLayoutNdKV/AtomLayoutMdQ : how the 2 GMMA atoms are tiled
//       across the M vs. N dimension for each of the three GEMMs.
//   kClusterN_  : thread-block cluster size along N.
//   elem_type   : input element type; accumulation is always fp32.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
         int kClusterN_ = 1, typename elem_type=cutlass::half_t>
struct Flash_bwd_seqqpar_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;  // accumulator type for all GEMMs
    using index_t = int64_t;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static_assert(kNWarps_ == 8);  // the atom layouts below assume exactly 2 GMMA atoms (8 warps)
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    // (M, N, K) = (query block, key block, head dim) tile shape.
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterN = kClusterN_;
    using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
    static constexpr int kStages = 2;  // smem pipeline depth (double buffering)
    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static_assert(!(SdP_swapAB && dKV_swapAB));  // If SdP_swapAB, then we don't swap for dKV
    // dQ can use an RS (A-operand in registers) MMA only when both the SdP and dQ
    // GEMMs are tiled 2-wide along M and neither operand is swapped.
    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS
    // --- S / dP GEMM: per-atom tile shape and how the 2 atoms tile M x N. ---
    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
        Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
        AtomLayoutSdP{}));
    // --- dK / dV GEMM. ---
    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
    >;
    // SS (both operands in smem) when SdP was not swapped, otherwise RS with a
    // K-major A operand fed from registers.
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !SdP_swapAB,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
        >{},
        AtomLayoutdKV{}));
    // --- dQ GEMM. ---
    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
    >;
    // Operand majors flip together with dQ_swapAB.
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
        >{},
        AtomLayoutdQ{}));
    // TMA copy atoms; Q/dO loads are multicast across the cluster when kClusterN > 1.
    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // cp.async (asynchronous global->shared copy) is available from SM80 onward.
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    // 128-bit vectorized loads: elements per load depends on element size.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape<_32, _8>,  // Thread layout, 8 threads per row
               Stride<_8, _1>>,
        Layout<Shape<_16, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
    // --- Smem layouts. Q/dO are pipelined per stage... actually here Q/dO are
    // single-buffered while K/V carry the kStages dimension (seq-q-parallel:
    // this CTA iterates over K/V blocks). ---
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdO = SmemLayoutQ;
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    // K/V smem carries a third mode of size kStages for the pipeline.
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    // Transposed per-stage view of K; the last stride (kBlockN * kHeadDim) is the
    // distance between consecutive pipeline stages.
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                                               make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                                   make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                                   make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    // Single-stage (kBlockN, kHeadDim) layouts for the dK/dV outputs.
    using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdV = SmemLayoutdK;
    // NOTE(review): transposed dK/dV views alias SmemLayoutKt (staged) — assumes
    // only stage 0 of that view is used for the single-stage outputs; confirm in kernel.
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;
    // TMA-compatible (kBlockM, kHeadDim) layout used when dQ is written via TMA.
    using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                                               make_stride(Int<kBlockM>{}, _1{}))));
    // Size in bytes of the dQ smem tile.
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemLayoutdKVt =
        decltype(cute::composition(SmemLayoutdKV{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                                               make_stride(Int<kBlockN>{}, _1{}))));
    // x2: one buffer each for dK and dV.
    static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;
    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    // STSM copy atoms: the transposed (U16x8 _T) variant is selected whenever the
    // corresponding GEMM has its A/B operands swapped.
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    // Aggregate smem storage for this (seq-q-parallel) variant.
    using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;
    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
    // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
    // TMA-based producer/consumer pipeline with one slot per smem stage.
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
  739. ////////////////////////////////////////////////////////////////////////////////////////////////////