// kernel_traits.h
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include "cute/algorithm/copy.hpp"
  6. #include "cute/atom/mma_atom.hpp"
  7. #include "cutlass/gemm/collective/collective_builder.hpp"
  8. #include "cutlass/cutlass.h"
  9. #include "cutlass/layout/layout.h"
  10. #include "cutlass/numeric_types.h"
  11. #include "cutlass/pipeline/pipeline.hpp"
  12. using namespace cute;
// Shared-memory storage for the forward kernel: Q, K, V input tiles plus the
// O output tile.  V and O occupy the same bytes (anonymous union) to save
// smem — valid only if the kernel finishes reading V before writing O
// (NOTE(review): inferred from the union; confirm against the kernel code).
// Member order and the union structure define the smem layout — do not reorder.
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
    union {
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // TMA transaction barrier for the Q tile
        cutlass::arch::ClusterBarrier barrier_O;             // barrier associated with O (guards the V/O union reuse — confirm)
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;  // kStages-deep TMA pipeline state for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;  // kStages-deep TMA pipeline state for V
        int tile_count_semaphore;  // counter shared across blocks, presumably for persistent tile scheduling
    };
};
// Use if Oaccum is too large for SharedStorageQKVO.
// Same contents as SharedStorageQKVO, but O overlaps BOTH K and V (not just V),
// for when the (larger, e.g. float) O accumulator would not fit in V's bytes alone.
// Member order and the union structure define the smem layout — do not reorder.
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOaccum {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    union {
        struct {
            cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        };
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // TMA transaction barrier for the Q tile
        cutlass::arch::ClusterBarrier barrier_O;             // barrier associated with O (guards K/V vs O reuse — confirm)
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;  // TMA pipeline state for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;  // TMA pipeline state for V
        int tile_count_semaphore;  // counter for persistent tile scheduling (presumed)
    };
};
// SharedStorage struct with no smem for O.
// Selected by Flash_fwd_kernel_traits when No_smem_O is true (the split-KV
// path), where the output bypasses shared memory.  No barrier_O is needed.
template <int kStages, class Gemm1Type, class Gemm2Type, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV>
struct SharedStorageQKV {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
    cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // TMA transaction barrier for the Q tile
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;  // TMA pipeline state for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;  // TMA pipeline state for V
        int tile_count_semaphore;  // counter for persistent tile scheduling (presumed)
    };
};
// Shared-memory storage for the fp8 forward kernel (see Flash_fwd_kernel_traits_fp8):
// adds smem_v_out, the destination buffer of the in-kernel V transpose, plus a
// non-TMA pipeline (pipeline_vt) coordinating transpose producers/consumers.
// smem_v_out and smem_o share bytes; member order defines the smem layout.
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;       // V as loaded
        union {
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;  // transposed V (same size as smem_v)
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;     // O reuses the transposed-V bytes
        };
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // TMA transaction barrier for the Q tile
        cutlass::arch::ClusterBarrier barrier_O;             // barrier associated with O
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;   // TMA pipeline state for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;   // TMA pipeline state for V
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;     // non-TMA pipeline for transposed V
        int tile_count_semaphore;     // counter for persistent tile scheduling (presumed)
        float softmax_scale_qk_log2;  // softmax scale pre-multiplied by log2(e), presumably for exp2-based softmax — confirm
        float descale_v;              // fp8 dequantization scale for V (presumed from name)
        bool seqlen_init_k;           // flag shared via smem; exact semantics defined by the kernel — confirm
    };
};
// Use if Oaccum is too large for SharedStorageQKVOVt.
// Same contents, but O overlaps the pair {V, transposed V} instead of only the
// transposed-V buffer, freeing enough bytes for a float O accumulator.
// Member order and the union structure define the smem layout — do not reorder.
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVtaccum {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        union {
            struct {
                cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;      // V as loaded
                cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;  // transposed V
            };
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;  // O reuses both V buffers' bytes
        };
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // TMA transaction barrier for the Q tile
        cutlass::arch::ClusterBarrier barrier_O;             // barrier associated with O
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;   // TMA pipeline state for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;   // TMA pipeline state for V
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;     // non-TMA pipeline for transposed V
        int tile_count_semaphore;     // counter for persistent tile scheduling (presumed)
        float softmax_scale_qk_log2;  // softmax scale times log2(e), presumably for exp2-based softmax — confirm
        float descale_v;              // fp8 dequantization scale for V (presumed from name)
        bool seqlen_init_k;           // flag shared via smem; semantics defined by the kernel — confirm
    };
};
// fp8 shared-memory storage with no smem for O (counterpart of SharedStorageQKV
// for the in-kernel-transpose path): V and its transposed copy are both kept,
// with no O buffer and no barrier_O.
template <int kStages, class Gemm1Type, class Gemm2Type, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV>
struct SharedStorageQKVVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;      // V as loaded
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;  // transposed V
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // TMA transaction barrier for the Q tile
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;   // TMA pipeline state for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;   // TMA pipeline state for V
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;     // non-TMA pipeline for transposed V
        int tile_count_semaphore;     // counter for persistent tile scheduling (presumed)
        float softmax_scale_qk_log2;  // softmax scale times log2(e), presumably for exp2-based softmax — confirm
        float descale_v;              // fp8 dequantization scale for V (presumed from name)
        bool seqlen_init_k;           // flag shared via smem; semantics defined by the kernel — confirm
    };
};
// Variant of SharedStorageQKVOVt with NO unions: V, transposed V, and O each
// get dedicated bytes.  Uses the most shared memory; presumably useful when
// overlapping lifetimes make the union variants unsafe — confirm at the use site.
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVt_nounion {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;      // V as loaded
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;  // transposed V
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;     // O, not overlapped with anything
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // TMA transaction barrier for the Q tile
        cutlass::arch::ClusterBarrier barrier_O;             // barrier associated with O
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;   // TMA pipeline state for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;   // TMA pipeline state for V
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;     // non-TMA pipeline for transposed V
        int tile_count_semaphore;     // counter for persistent tile scheduling (presumed)
        float softmax_scale_qk_log2;  // softmax scale times log2(e), presumably for exp2-based softmax — confirm
        float descale_v;              // fp8 dequantization scale for V (presumed from name)
        bool seqlen_init_k;           // flag shared via smem; semantics defined by the kernel — confirm
    };
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true.
//
// Compile-time traits for the (fp16/bf16) forward attention kernel: tile
// shapes, tiled-MMA types, shared-memory layouts, copy layouts, pipeline
// types, and the SharedStorage struct they imply.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::half_t, bool Is_split_=false, int kBlockH_ = 1>
struct Flash_fwd_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using FinalOutputType = elem_type;
    // Split-KV accumulates partial outputs in float; otherwise output in the input type.
    using OutputType = std::conditional_t<Is_split_, float, FinalOutputType>;
    using index_t = int64_t;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp;  // one warp produces (TMA loads)
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;  // warp-specialized producer/consumer kernel
    static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");
    static constexpr int kBlockM = kBlockM_;  // queries per tile
    static constexpr int kBlockN = kBlockN_;  // keys/values per tile
    static constexpr int kBlockH = kBlockH_;  // heads packed per M-tile (GQA-style packing — confirm)
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static_assert(kBlockM % kBlockH == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;  // thread-block cluster along M only
    static constexpr int kStages = kStages_;  // pipeline depth for K/V smem buffers
    static constexpr bool Is_split = Is_split_;
    static constexpr bool No_smem_O = Is_split;  // split path skips the smem O staging buffer
    // One warp-group-wide MMA atom per 64 rows of the M tile.
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    // Q@K^T GEMM: operand A from registers (rs) when Q is kept in registers, else from smem (ss).
    using TiledMma0 = decltype(cute::make_tiled_mma(
        std::conditional_t<
            Is_Q_in_regs,
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
        >{},
        AtomLayoutMNK{}));
    // P@V GEMM: P (softmax output) comes from registers; V is read MN-major.
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    // for gmem -> smem Q copy:
    // re-factors the (M, K) Q tile as (M/H, H, K) when multiple heads are packed per tile.
    using FactoringLayoutQ = Layout<Shape<Int<kBlockM/kBlockH>, Int<kBlockH>, Int<kHeadDim>>,
                                    Stride<Int<kBlockH>, _1, Int<kBlockM>>>;
    using TileShapeQCopy = std::conditional_t<(kBlockH > 1),
        decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>;
    using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>;
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    // K and V smem carry a trailing kStages mode for pipelining.
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int<kStages>{})));
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVt =
        decltype(composition(SmemLayoutV{},
                             make_ordered_layout(
                                 make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                                 Step<_2, _1, _3>{})));
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
    // for smem -> gmem O copy: mirrors the Q copy factoring when kBlockH > 1.
    using TileShapeOCopy = TileShapeQCopy;
    using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>;
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    // Pick the smem aggregate: with an O staging buffer, or without (split path).
    using SharedStorage = std::conditional_t<!No_smem_O,
        SharedStorageQKVO<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>,
        SharedStorageQKV<kStages, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV>>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
// Traits struct for fp8 kernel with in-kernel transpose.
//
// Differs from Flash_fwd_kernel_traits in that: the element type is 8-bit,
// output defaults to bf16, a whole warp group produces, and V is transposed
// inside the kernel via LDSM/STSM-shaped copy layouts (fp8 WGMMA cannot read
// the B operand MN-major, hence the explicit memory transpose — confirm).
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t, bool Is_split_ = false, int kBlockH_ = 1>
struct Flash_fwd_kernel_traits_fp8 {
    using Element = elem_type;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);
    using ElementAccum = float;
    using FinalOutputType = cutlass::bfloat16_t;
    // Split-KV accumulates partial outputs in float; otherwise bf16.
    using OutputType = std::conditional_t<Is_split_, float, FinalOutputType>;
    using index_t = int64_t;
    static constexpr bool Is_split = Is_split_;
    static constexpr bool No_smem_O = false;
    // NOTE: not using smem for epilogue degrades perf substantially.
    // static constexpr bool No_smem_O = Is_split;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;  // full warp group produces (vs one warp in fp16)
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;  // warp-specialized producer/consumer kernel
    static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers");
    static constexpr int kBlockM = kBlockM_;  // queries per tile
    static constexpr int kBlockN = kBlockN_;  // keys/values per tile
    static constexpr int kBlockH = kBlockH_;  // heads packed per M-tile (GQA-style packing — confirm)
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static_assert(kBlockM % kBlockH == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;  // thread-block cluster along M only
    static constexpr int kStages = kStages_;
    static_assert(kStages > 1);
    // Use this to save enough smem when writing out in float precision.
    static constexpr bool VO_union_all = Is_split && (kBlockM != 64) && (kHeadDim == 256);
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    // Q@K^T GEMM: both operands from smem.
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));
    // P@V GEMM: P from registers; V read from the transposed smem buffer (default K-major).
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{}))>(),
        AtomLayoutMNK{}));
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    // for gmem -> smem Q copy:
    // re-factors the (M, K) Q tile as (M/H, H, K) when multiple heads are packed per tile.
    using FactoringLayoutQ = Layout<Shape<Int<kBlockM/kBlockH>, Int<kBlockH>, Int<kHeadDim>>,
                                    Stride<Int<kBlockH>, _1, Int<kBlockM>>>;
    using TileShapeQCopy = std::conditional_t<(kBlockH > 1),
        decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>;
    using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>;
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    // The in-kernel transpose works on 64x64 sub-tiles of V.
    using TransposeShapeAtomV = Shape<_64, _64>;
    using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    // for fp8 in-kernel transpose -- src layout:
    // divide V into 64x64 tiles, then factor each tile by the LDSM access shape.
    using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
    using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
    using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
        shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
    using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
    // For fp8, this is the memory transpose.
    using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutVt =
        decltype(tile_to_shape(SmemLayoutAtomVt{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
    // for fp8 in-kernel transpose -- dst layout
    using SmemLayoutVtTrans =
        decltype(composition(SmemLayoutVt{},
                             make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
    using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
#ifndef NO_FP8_COLUMN_PERMUTE
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
#else
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
#endif
    using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
        shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
    using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
    // for smem -> gmem O copy: mirrors the Q copy factoring when kBlockH > 1.
    using TileShapeOCopy = TileShapeQCopy;
    using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>;
    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
                                  Stride<_4, _32, _1, _0>>;
    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
                                 Stride<_0, _2, Stride<_4, _1>, _8>>;
    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, OutputType>{},
                                                 ThreadLayoutrO{}, ValueLayoutrO{}));
    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    // Pick the smem aggregate: normal, O-overlapping-all (VO_union_all), or no-O variants.
    using SharedStorage = std::conditional_t<!No_smem_O,
        std::conditional_t<!VO_union_all,
            SharedStorageQKVOVt<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>,
            SharedStorageQKVOVtaccum<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>>,
        SharedStorageQKVVt<kStages, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV>>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
  353. ////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory storage for the backward kernel: Q/dO are pipelined inputs,
// K/V are resident and later overlapped (union) with the dK/dV outputs.
// Specialized on Has_P_smem: whether the P (probability) matrix is staged in smem.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV;

// Has_P_smem == true: P and dS each get dedicated smem buffers.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {  // inputs K, V ...
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {  // ... reuse the same bytes for outputs dK, dV in the epilogue (presumed)
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;  // TMA transaction barrier for the K tile
        cutlass::arch::ClusterTransactionBarrier barrier_V;  // TMA transaction barrier for the V tile
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;   // TMA pipeline state for Q
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;  // TMA pipeline state for dO
    };
};
// Has_P_smem == false specialization of SharedStorageQKVdOdKV: P is not staged
// in smem, so smem_p shares bytes with smem_ds instead of having its own buffer.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {  // inputs K, V ...
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {  // ... reuse the same bytes for outputs dK, dV in the epilogue (presumed)
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;  // TMA transaction barrier for the K tile
        cutlass::arch::ClusterTransactionBarrier barrier_V;  // TMA transaction barrier for the V tile
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;   // TMA pipeline state for Q
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;  // TMA pipeline state for dO
    };
};
// Warp-specialized ("WS") variant of the backward-kernel shared storage: adds a
// float dQ accumulator buffer plus LSE / dP-sum scratch (128 floats each;
// presumably one per row of the M tile — confirm against kBlockM at use site).
// Specialized on Has_P_smem, as for SharedStorageQKVdOdKV.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS;

// Has_P_smem == true: P and dS each get dedicated smem buffers.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {  // inputs K, V ...
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {  // ... reuse the same bytes for outputs dK, dV in the epilogue (presumed)
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;  // float dQ accumulator
        cute::array_aligned<float, 128> smem_lse;    // per-row log-sum-exp scratch (presumed from name)
        cute::array_aligned<float, 128> smem_dpsum;  // per-row sum(dO * O) scratch (presumed from name)
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;  // TMA transaction barrier for the K tile
        cutlass::arch::ClusterTransactionBarrier barrier_V;  // TMA transaction barrier for the V tile
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;   // TMA pipeline state for Q
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;  // TMA pipeline state for dO
    };
};
// Has_P_smem == false specialization of SharedStorageQKVdOdKVWS: smem_p shares
// bytes with smem_ds instead of having its own buffer.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {  // inputs K, V ...
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {  // ... reuse the same bytes for outputs dK, dV in the epilogue (presumed)
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;  // float dQ accumulator
        cute::array_aligned<float, 128> smem_lse;    // per-row log-sum-exp scratch (presumed from name)
        cute::array_aligned<float, 128> smem_dpsum;  // per-row sum(dO * O) scratch (presumed from name)
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;  // TMA transaction barrier for the K tile
        cutlass::arch::ClusterTransactionBarrier barrier_V;  // TMA transaction barrier for the V tile
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;   // TMA pipeline state for Q
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;  // TMA pipeline state for dO
    };
};
// Backward-kernel shared storage for the sequence-parallel-over-Q ("SeqqPar")
// schedule: here K/V are the resident tiles while Q/dO are pipelined and
// overlapped (union) with the dQ output.  Specialized on Has_P_smem.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar;

// Has_P_smem == true: P and dS each get dedicated smem buffers.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            struct {  // inputs Q, dO ...
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {  // ... reuse the same bytes for the output dQ in the epilogue (presumed)
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;   // TMA transaction barrier for the Q tile
        cutlass::arch::ClusterTransactionBarrier barrier_dO;  // TMA transaction barrier for the dO tile
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;  // TMA pipeline state for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;  // TMA pipeline state for V
    };
};
// Specialization for Has_P_smem == false: P is not materialized in smem, so
// smem_p may alias smem_ds, saving one tile of shared memory.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
                                    SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        // Q/dO inputs and the dQ output reuse the same storage (see the 'true'
        // specialization above for the same pattern).
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    // Synchronization objects for the TMA loads and the K/V mainloop pipelines.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};
  550. ////////////////////////////////////////////////////////////////////////////////////////////////////
// Compile-time configuration ("traits") for the FlashAttention backward kernel.
// Bundles everything the kernel needs at compile time: tile shapes, tiled-MMA
// configurations for the S/dP, dK/dV and dQ GEMMs, gmem/smem copy atoms, smem
// layouts (plus their transposed views), the shared-storage type, and the TMA
// mainloop pipeline type.
//
// Template parameters:
//   kHeadDim_, kBlockM_, kBlockN_ : head dimension and M/N tile sizes.
//   kNWarps_                      : total warps; 8 or 12 (12 => warp-specialized).
//   SdP_swapAB_/dKV_swapAB_/dQ_swapAB_ : swap the A/B operands of the
//       corresponding GEMM (flips tile shapes, operand majors, and copy atoms).
//   AtomLayoutMSdP/NdKV/MdQ       : distribution of the 2 MMA atoms over M vs N.
//   kClusterN_                    : thread-block cluster size along N.
//   elem_type                     : input element type (default half).
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
         int kClusterN_ = 1, typename elem_type=cutlass::half_t>
struct Flash_bwd_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;  // GEMM accumulation type
    using index_t = int64_t;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    // Thread count used by the non-warp-specialized layouts below (fixed 8 warps).
    static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
    // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
    // Threads cooperating on the dQ copies/layouts (2 warpgroups).
    static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;

    static_assert(kNWarps_ == 8 || kNWarps_ == 12);
    // 12 warps selects the warp-specialized (WS) shared-storage variant below.
    static constexpr bool Is_WS = kNWarps_ >= 12;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterN = kClusterN_;
    using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;

    // Number of smem pipeline stages for the Q/dO loads (see SmemLayoutQ below).
    static constexpr int kStages = 2;

    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV

    // The dQ GEMM can keep its A operand in registers (RS) only in this exact
    // configuration; otherwise it is smem-smem (SS).
    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS

    // --- S / dP GEMM: per-atom tile and atom layout (operand order flips with SdP_swapAB).
    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
        Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
        AtomLayoutSdP{}));

    // --- dK / dV GEMM.
    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
    >;
    // SS when SdP is not swapped, RS (A from registers) when it is.
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !SdP_swapAB,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
        >{},
        AtomLayoutdKV{}));

    // --- dQ GEMM.
    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
        // Shape<Int<kBlockM>, Int<kHeadDim >, Int<kBlockN>>,
        // Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
        // Layout<Shape<Int<1>, Int<1>, _1>>,
        // Layout<Shape<Int<1>, Int<1>, _1>>
    >;
    // Operand majors for the dQ GEMM, swapped along with A/B.
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
        >{},
        AtomLayoutdQ{}));

    // --- Global-memory copy atoms: TMA (multicast across the cluster for Q/dO).
    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;

    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    // Elements per 128-bit vectorized load.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemLayoutAtomdQ = Layout<Shape <Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                    Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtomdQ{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <Int<kNThreadsdQ / 8>, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <Int<kNThreadsdQ / 16>, _16>,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store

    // --- Shared-memory layouts. Q/dO are pipelined over kStages; K/V are single tiles.
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ =
        decltype(tile_to_shape(SmemLayoutAtomQ{},
                 make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutdO = SmemLayoutQ;
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));

    // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
    //     decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                             make_stride(Int<kBlockN>{}, _1{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));

    // using SmemLayoutdQacct =
    //     decltype(cute::composition(SmemLayoutdQacc{},
    //              make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    //                          make_stride(Int<kBlockM>{}, _1{}))));

    // dK/dV outputs reuse the K/V layouts (same tile shapes).
    using SmemLayoutdK = SmemLayoutK;
    using SmemLayoutdV = SmemLayoutV;
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;

    // NOTE(review): kSwizzle is currently unused in this struct — the atom
    // below hardcodes Swizzle<3, 3, 3> (the kSwizzle variant is commented out).
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    using SmemLayoutAtomdQ = decltype(
        // composition(Swizzle<kSwizzle, 3, 3>{},
        composition(Swizzle<3, 3, 3>{},
                    Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
                           Stride<Int<32>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);

    // NOTE(review): the atom selector below is built from the M and N tile dims
    // (get<0>, get<1>) but tiled over (M, K) via select<0, 2> — confirm this is
    // intentional (it matters only when kBlockN != kHeadDim).
    using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdQacc = SmemLayoutdQ;
    using SmemLayoutdQacct = SmemLayoutdQt;
    // Double-buffered (x2) variant of the dQ accumulation layout.
    using SmemLayoutdQacc2 = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
    // using SmemLayoutdQacct =
    //     decltype(cute::composition(SmemLayoutdQacc{},
    //              make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    //                          make_stride(Int<kBlockM>{}, _1{}))));
    using RmemTiledCopydQacc = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store

    // --- Register->smem copy atoms: N-major STSM normally, transposed STSM when swapped.
    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;

    // Shared-storage type: warp-specialized variant when Is_WS, otherwise the
    // plain one. Has_P_smem is !SdP_swapAB in both cases.
    using SharedStorage = std::conditional_t<
        !Is_WS,
        SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
            SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
        SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
            SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
            // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
    >;

    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
    // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
  784. ////////////////////////////////////////////////////////////////////////////////////////////////////
// Traits for the backward kernel parallelized over the query (seqq) dimension.
// Mirrors Flash_bwd_kernel_traits above, but with the pipelining roles swapped:
// here K/V are the multi-stage pipelined operands while Q/dO are single tiles,
// and dQ (not dK/dV) is the TMA-written output. No warp-specialized variant
// (kNWarps_ must be 8). Template parameters have the same meaning as in
// Flash_bwd_kernel_traits.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
         int kClusterN_ = 1, typename elem_type=cutlass::half_t>
struct Flash_bwd_seqqpar_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;  // GEMM accumulation type
    using index_t = int64_t;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static_assert(kNWarps_ == 8);

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterN = kClusterN_;
    using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;

    // Number of smem pipeline stages for the K/V loads (see SmemLayoutK/V below).
    static constexpr int kStages = 2;

    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV

    // dQ GEMM keeps its A operand in registers (RS) only in this configuration.
    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS

    // --- S / dP GEMM (same construction as in Flash_bwd_kernel_traits).
    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
        Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
        AtomLayoutSdP{}));

    // --- dK / dV GEMM.
    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
    >;
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !SdP_swapAB,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
        >{},
        AtomLayoutdKV{}));

    // --- dQ GEMM.
    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
    >;
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
        >{},
        AtomLayoutdQ{}));

    // --- Global-memory copy atoms.
    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;

    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    // Elements per 128-bit vectorized load.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_32, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_16, _16>,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store

    // --- Shared-memory layouts. Q/dO are single tiles; K/V are pipelined over kStages.
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdO = SmemLayoutQ;
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));

    // dK/dV outputs are single (unstaged) tiles.
    using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdV = SmemLayoutdK;
    // NOTE(review): these alias the staged (3-D) SmemLayoutKt while SmemLayoutdK
    // above is a single tile — confirm the dKt/dVt views are only used with the
    // stage coordinate handled by the caller.
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;
    // Layout used for the TMA store of dQ.
    using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));

    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemLayoutdKVt =
        decltype(cute::composition(SmemLayoutdKV{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                             make_stride(Int<kBlockN>{}, _1{}))));
    // Two buffers: one for dK and one for dV.
    static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;

    // --- Register->smem copy atoms: N-major STSM normally, transposed when swapped.
    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;

    // Has_P_smem = !SdP_swapAB: P gets its own buffer only when SdP is not swapped.
    using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;

    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
    // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
  990. ////////////////////////////////////////////////////////////////////////////////////////////////////