/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

using namespace cute;
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
    union {
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        int tile_count_semaphore;
    };
};
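
// Added note (illustrative, not part of the original header): because smem_v
// and smem_o share a union, the V tiles and the O tile reuse the same bytes --
// V has already been consumed by the second GEMM by the time O is staged for
// the epilogue. As a worked example with hypothetical sizes kBlockM = kBlockN
// = kHeadDim = 128 and fp16 elements, one 128 x 128 tile is 32 KiB, so the
// union saves a full O tile's worth of shared memory per stage of V.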
// Use if Oaccum is too large for SharedStorageQKVO
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOaccum {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    union {
        struct {
            cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        };
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        int tile_count_semaphore;
    };
};
// SharedStorage struct with no smem for O
template <int kStages, class Gemm1Type, class Gemm2Type, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV>
struct SharedStorageQKV {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
    cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        int tile_count_semaphore;
    };
};
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
        };
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;
        float softmax_scale_qk_log2;
        float descale_v;
        bool seqlen_init_k;
    };
};
// Use if Oaccum is too large for SharedStorageQKVOVt
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVtaccum {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        union {
            struct {
                cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
                cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
            };
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
        };
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;
        float softmax_scale_qk_log2;
        float descale_v;
        bool seqlen_init_k;
    };
};
template <int kStages, class Gemm1Type, class Gemm2Type, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV>
struct SharedStorageQKVVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;
        float softmax_scale_qk_log2;
        float descale_v;
        bool seqlen_init_k;
    };
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::half_t, bool Is_split_=false, int kBlockH_ = 1>
struct Flash_fwd_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using FinalOutputType = elem_type;
    using OutputType = std::conditional_t<Is_split_, float, FinalOutputType>;
    using index_t = int64_t;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp;

    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;
    static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kBlockH = kBlockH_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static_assert(kBlockM % kBlockH == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;

    static constexpr int kStages = kStages_;

    static constexpr bool Is_split = Is_split_;
    static constexpr bool No_smem_O = Is_split;

    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
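    // Added note (not in the original): Hopper GMMA atoms are 64 rows tall
    // along M, so kBlockM / 64 atoms are stacked in the M mode of the tiled
    // MMA (e.g., a hypothetical kBlockM of 128 yields a 2x1x1 atom layout).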
    using TiledMma0 = decltype(cute::make_tiled_mma(
        std::conditional_t<
            Is_Q_in_regs,
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
        >{},
        AtomLayoutMNK{}));

    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    // for gmem -> smem Q copy
    using FactoringLayoutQ = Layout<Shape<Int<kBlockM/kBlockH>, Int<kBlockH>, Int<kHeadDim>>,
                                    Stride<Int<kBlockH>, _1, Int<kBlockM>>>;
    using TileShapeQCopy = std::conditional_t<(kBlockH > 1),
        decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>;
    using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>;

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int<kStages>{})));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVt =
        decltype(composition(SmemLayoutV{},
                 make_ordered_layout(
                     make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                     Step<_2, _1, _3>{})));
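    // Added explanation (sketch, not in the original): SmemLayoutVt composes
    // SmemLayoutV with a (kHeadDim, kBlockN, kStages) ordered layout, so the
    // second GEMM can index V as V^T without moving any data -- the same
    // shared-memory bytes are addressed with the first two modes swapped,
    // which is what "transpose in terms of the view" means above.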
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    // for smem -> gmem O copy
    using TileShapeOCopy = TileShapeQCopy;
    using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>;

    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;

    using SharedStorage = std::conditional_t<!No_smem_O,
        SharedStorageQKVO<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>,
        SharedStorageQKV<kStages, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV>>;

    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
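
// Illustrative instantiation (hypothetical tile sizes chosen for this sketch,
// not a configuration prescribed by this header): fp16, head dim 128,
// 128 x 128 tiles, 12 warps, 2 pipeline stages.
namespace flash_fwd_traits_example {
    using Traits = Flash_fwd_kernel_traits</*kHeadDim_=*/128, /*kBlockM_=*/128,
                                           /*kBlockN_=*/128, /*kNWarps_=*/12, /*kStages_=*/2>;
    static_assert(Traits::kNThreads == 384, "12 warps of 32 threads each");
    // Non-split kernels write the final output type directly (fp16 here);
    // split kernels accumulate partial results in float instead.
    static_assert(std::is_same_v<Traits::OutputType, cutlass::half_t>);
} // namespace flash_fwd_traits_example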
// Traits struct for fp8 kernel with in-kernel transpose
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t, bool Is_split_ = false, int kBlockH_ = 1>
struct Flash_fwd_kernel_traits_fp8 {
    using Element = elem_type;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);
    using ElementAccum = float;
    using FinalOutputType = cutlass::bfloat16_t;
    using OutputType = std::conditional_t<Is_split_, float, FinalOutputType>;
    using index_t = int64_t;

    static constexpr bool Is_split = Is_split_;
    static constexpr bool No_smem_O = false;
    // NOTE: not using smem for epilogue degrades perf substantially.
    // static constexpr bool No_smem_O = Is_split;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;

    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    static_assert(kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;
    static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers");

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kBlockH = kBlockH_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static_assert(kBlockM % kBlockH == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;

    static constexpr int kStages = kStages_;
    static_assert(kStages > 1);

    // Use this to save enough smem when writing out in float precision.
    static constexpr bool VO_union_all = Is_split && (kBlockM != 64) && (kHeadDim == 256);
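    // Added worked example (hypothetical numbers, not from the original): with
    // Is_split, O is staged in float, so for kBlockM = 128 and kHeadDim = 256
    // the O tile alone is 128 * 256 * 4 B = 128 KiB. Only those largest
    // configurations need the fully-unioned SharedStorageQKVOVtaccum layout;
    // smaller tiles fit with the plain smem_v_out / smem_o union.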
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{}))>(),
        AtomLayoutMNK{}));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    // for gmem -> smem Q copy
    using FactoringLayoutQ = Layout<Shape<Int<kBlockM/kBlockH>, Int<kBlockH>, Int<kHeadDim>>,
                                    Stride<Int<kBlockH>, _1, Int<kBlockM>>>;
    using TileShapeQCopy = std::conditional_t<(kBlockH > 1),
        decltype(shape(FactoringLayoutQ{})), decltype(select<0, 2>(TileShape_MNK{}))>;
    using SmemLayoutQCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutQ{}, FactoringLayoutQ{})), SmemLayoutQ>;

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using TransposeShapeAtomV = Shape<_64, _64>;
    using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    // for fp8 in-kernel transpose -- src layout
    using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
    using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
    using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
        shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
    using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));

    // For fp8, this is the memory transpose.
    using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutVt =
        decltype(tile_to_shape(SmemLayoutAtomVt{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));

    // for fp8 in-kernel transpose -- dst layout
    using SmemLayoutVtTrans =
        decltype(composition(SmemLayoutVt{},
                 make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
    using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
#ifndef NO_FP8_COLUMN_PERMUTE
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
#else
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
#endif
    using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
        shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
    using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
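    // Added explanation (sketch, not in the original): fp8 WGMMA consumes both
    // operands K-major, so unlike the fp16 path a view transpose of V is not
    // enough and V must be physically transposed in shared memory.
    // SmemLayoutTransposeV / SmemLayoutTransposeVt carve the V tiles into
    // 64 x 64 blocks shaped for ldmatrix (LDSM) reads from the source layout
    // and stmatrix (STSM) writes into the transposed destination layout.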
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    // for smem -> gmem O copy
    using TileShapeOCopy = TileShapeQCopy;
    using SmemLayoutOCopy = std::conditional_t<(kBlockH > 1),
        decltype(composition(SmemLayoutO{}, FactoringLayoutQ{})), SmemLayoutO>;

    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
                                  Stride<_4, _32, _1, _0>>;
    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
                                 Stride<_0, _2, Stride<_4, _1>, _8>>;
    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, OutputType>{},
                                                 ThreadLayoutrO{}, ValueLayoutrO{}));
    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));

    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;

    using SharedStorage = std::conditional_t<!No_smem_O,
        std::conditional_t<!VO_union_all,
            SharedStorageQKVOVt<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>,
            SharedStorageQKVOVtaccum<kStages, Element, Element, OutputType, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>>,
        SharedStorageQKVVt<kStages, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV>>;

    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
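
// Illustrative instantiation (hypothetical parameters for this sketch): e4m3
// inputs, head dim 256, 128 x 128 tiles, 12 warps, 2 pipeline stages.
namespace flash_fwd_traits_fp8_example {
    using Traits = Flash_fwd_kernel_traits_fp8</*kHeadDim_=*/256, /*kBlockM_=*/128,
                                               /*kBlockN_=*/128, /*kNWarps_=*/12, /*kStages_=*/2>;
    // fp8 kernels upcast the final output to bf16 unless the kernel is split.
    static_assert(std::is_same_v<Traits::OutputType, cutlass::bfloat16_t>);
    static_assert(!Traits::VO_union_all, "the full V/O union is only needed for split kernels");
} // namespace flash_fwd_traits_fp8_example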
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV;

template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};

template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union {  // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS;

template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
        cute::array_aligned<float, 128> smem_lse;
        cute::array_aligned<float, 128> smem_dpsum;
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};

template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union {  // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
        cute::array_aligned<float, 128> smem_lse;
        cute::array_aligned<float, 128> smem_dpsum;
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar;

template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};

template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        union {  // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    struct {
        cute::uint64_t tma_load_mbar[8];  // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};

////////////////////////////////////////////////////////////////////////////////////////////////////
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
         int kClusterN_ = 1, typename elem_type=cutlass::half_t>
struct Flash_bwd_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using index_t = int64_t;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
    // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
    static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;

    static_assert(kNWarps_ == 8 || kNWarps_ == 12);
    static constexpr bool Is_WS = kNWarps_ >= 12;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterN = kClusterN_;
    using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;

    static constexpr int kStages = 2;

    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static_assert(!(SdP_swapAB && dKV_swapAB));  // If SdP_swapAB, then we don't swap for dKV.
    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS.
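    // Added note (not in the original): "RS" refers to GMMA atoms that read
    // operand A from registers (rs_op_selector below) rather than shared
    // memory (ss_op_selector). The dQ MMA can only keep its A operand (dS)
    // in registers when neither the S/dP MMA nor the dQ MMA swaps A and B.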
    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
        Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
        AtomLayoutSdP{}));

    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
    >;
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !SdP_swapAB,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
        >{},
        AtomLayoutdKV{}));

    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
        // Shape<Int<kBlockM>, Int<kHeadDim>, Int<kBlockN>>,
        // Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
        // Layout<Shape<Int<1>, Int<1>, _1>>,
        // Layout<Shape<Int<1>, Int<1>, _1>>
    >;
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
        >{},
        AtomLayoutdQ{}));

    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemLayoutAtomdQ = Layout<Shape<Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                    Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtomdQ{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape<Int<kNThreadsdQ / 8>, _8>,  // Thread layout, 8 threads per row
               Stride<_8, _1>>,
        Layout<Shape<Int<kNThreadsdQ / 16>, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ =
        decltype(tile_to_shape(SmemLayoutAtomQ{},
                 make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutdO = SmemLayoutQ;

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));

    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));

    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));

    // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
    //     decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                             make_stride(Int<kBlockN>{}, _1{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));

    // using SmemLayoutdQacct =
    //     decltype(cute::composition(SmemLayoutdQacc{},
    //              make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    //                          make_stride(Int<kBlockM>{}, _1{}))));

    using SmemLayoutdK = SmemLayoutK;
    using SmemLayoutdV = SmemLayoutV;
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;

    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    using SmemLayoutAtomdQ = decltype(
        // composition(Swizzle<kSwizzle, 3, 3>{},
        composition(Swizzle<3, 3, 3>{},
                    Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
                           Stride<Int<32>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);

    using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdQacc = SmemLayoutdQ;
    using SmemLayoutdQacct = SmemLayoutdQt;
    using SmemLayoutdQacc2 = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
    // using SmemLayoutdQacct =
    //     decltype(cute::composition(SmemLayoutdQacc{},
    //              make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    //                          make_stride(Int<kBlockM>{}, _1{}))));
    using RmemTiledCopydQacc = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;

    using SharedStorage = std::conditional_t<
        !Is_WS,
        SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
            SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
        SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
            SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
            // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
    >;

    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
    // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
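
// Illustrative instantiation (hypothetical parameters for this sketch): fp16,
// head dim 128, 64 x 128 tiles, 8 warps, no operand swapping.
namespace flash_bwd_traits_example {
    using Traits = Flash_bwd_kernel_traits</*kHeadDim_=*/128, /*kBlockM_=*/64, /*kBlockN_=*/128,
                                           /*kNWarps_=*/8, /*SdP_swapAB_=*/false,
                                           /*dKV_swapAB_=*/false, /*dQ_swapAB_=*/false>;
    static_assert(!Traits::Is_WS, "8 warps selects the non-warp-specialized path");
    static_assert(!Traits::Mma_dQ_is_RS, "RS dQ needs AtomLayoutMSdP == AtomLayoutMdQ == 2");
} // namespace flash_bwd_traits_example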
////////////////////////////////////////////////////////////////////////////////////////////////////

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
         int kClusterN_ = 1, typename elem_type=cutlass::half_t>
struct Flash_bwd_seqqpar_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using index_t = int64_t;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static_assert(kNWarps_ == 8);

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kClusterN = kClusterN_;
    using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;

    static constexpr int kStages = 2;

    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static_assert(!(SdP_swapAB && dKV_swapAB));  // If SdP_swapAB, then we don't swap for dKV.
    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS.

    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
        Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
        AtomLayoutSdP{}));

    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
    >;
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !SdP_swapAB,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
        >{},
        AtomLayoutdKV{}));

    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
    >;
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
        >{},
        AtomLayoutdQ{}));

    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape<_32, _8>,  // Thread layout, 8 threads per row
               Stride<_8, _1>>,
        Layout<Shape<_16, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdO = SmemLayoutQ;

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));

    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));

    using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdV = SmemLayoutdK;
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;

    using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));

    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemLayoutdKVt =
        decltype(cute::composition(SmemLayoutdKV{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                             make_stride(Int<kBlockN>{}, _1{}))));
    static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;

    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;

    using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;

    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
    // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
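
// Illustrative instantiation (hypothetical parameters for this sketch): the
// sequence-parallel backward traits require exactly 8 warps.
namespace flash_bwd_seqqpar_traits_example {
    using Traits = Flash_bwd_seqqpar_kernel_traits</*kHeadDim_=*/128, /*kBlockM_=*/64,
                                                   /*kBlockN_=*/128, /*kNWarps_=*/8,
                                                   /*SdP_swapAB_=*/false, /*dKV_swapAB_=*/false,
                                                   /*dQ_swapAB_=*/false>;
    static_assert(Traits::kStages == 2, "this traits struct hard-codes a 2-stage pipeline");
} // namespace flash_bwd_seqqpar_traits_example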
////////////////////////////////////////////////////////////////////////////////////////////////////