/******************************************************************************
 * kernel_traits.h
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
  4. #pragma once
  5. #include "cute/algorithm/copy.hpp"
  6. #include "cute/atom/mma_atom.hpp"
  7. #include "cutlass/gemm/collective/collective_builder.hpp"
  8. #include "cutlass/cutlass.h"
  9. #include "cutlass/layout/layout.h"
  10. #include "cutlass/numeric_types.h"
  11. #include "cutlass/pipeline/pipeline.hpp"
  12. using namespace cute;
// Shared-memory layout for the forward kernel (fp16/bf16 path): Q, K and V
// input tiles plus the O output tile, followed by synchronization state.
// smem_v and smem_o are overlapped in a union — presumably O is only written
// after V has been fully consumed, so the buffers are never live together
// (TODO confirm against the kernel's epilogue ordering).
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;  // Q tile
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;  // K tiles (kStages deep via SmemLayoutK)
    union {
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;   // V tiles
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;  // O tile, reuses V's storage
    };
    // Synchronization state (anonymous struct keeps members addressable by name).
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // presumably signals Q-tile TMA arrival
        cutlass::arch::ClusterBarrier barrier_O;             // presumably gates O-buffer reuse
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;  // producer/consumer pipeline for K
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;  // producer/consumer pipeline for V
        int tile_count_semaphore;  // presumably for persistent-kernel tile scheduling — TODO confirm
    };
};
// Shared-memory layout for the forward kernel variant that transposes V in
// shared memory (used by the fp8 traits below). Unlike SharedStorageQKVO,
// smem_v gets its own dedicated buffer; the transposed copy smem_v_out is
// what shares storage with smem_o.
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;  // Q tile
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;  // K tiles
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;  // V tiles as loaded (pre-transpose)
        // Transposed V and O overlap — presumably never live at the same time
        // (TODO confirm against the kernel's write ordering).
        union {
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;  // V after in-kernel transpose
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;     // O tile
        };
    };
    // Synchronization state.
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // presumably signals Q-tile TMA arrival
        cutlass::arch::ClusterBarrier barrier_O;             // presumably gates O-buffer reuse
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        // Non-TMA pipeline: the transpose is done by compute threads, not TMA.
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;  // presumably for persistent-kernel tile scheduling — TODO confirm
    };
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
// NOTE(review): Share_Q_K_smem is not a template parameter of this struct —
// the comment above appears stale; verify against the rest of the codebase.
//
// Compile-time traits for the fp16/bf16 forward kernel: thread counts, tile
// shapes, tiled-MMA types and shared-memory layouts, all derived from the
// template parameters.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::half_t>
struct Flash_fwd_kernel_traits {
    using Element = elem_type;        // input/compute element type
    using ElementAccum = float;       // accumulate in fp32
    using OutputType = elem_type;     // output written in the input precision
    using index_t = int64_t;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    // Producer side is a single warp here (contrast with the fp8 traits,
    // which use a full warp group).
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    // Warp specialization is enabled at 12+ warps.
    static constexpr bool Is_WS = kNWarps_ >= 12;
    static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");
    static constexpr int kBlockM = kBlockM_;    // rows of Q per block
    static constexpr int kBlockN = kBlockN_;    // rows of K/V per block
    static constexpr int kHeadDim = kHeadDim_;  // head dimension
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
    static constexpr int kStages = kStages_;  // pipeline depth for K/V smem staging
    // One MMA atom per 64 rows of M; replicate the atom along M for larger kBlockM.
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    // First GEMM (presumably S = Q*K^T): operand A comes from registers when
    // Is_Q_in_regs, otherwise both operands are read from shared memory.
    using TiledMma0 = decltype(cute::make_tiled_mma(
        std::conditional_t<
            Is_Q_in_regs,
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
        >{},
        AtomLayoutMNK{}));
    // Second GEMM (presumably O = P*V): A from registers; B read MN-major so the
    // kernel can use the transposed view of V (SmemLayoutVt below).
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));
    // Smem layouts: atoms chosen by CUTLASS's selector (picks swizzle for the
    // given tile extents), then tiled out to the full block shape.
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    // K is multi-stage: (kBlockN, kHeadDim, kStages).
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int<kStages>{})));
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVt =
        decltype(composition(SmemLayoutV{},
                 make_ordered_layout(
                     make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                     Step<_2, _1, _3>{})));
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
    // ldmatrix atom used when loading Q from smem into registers.
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ,
                                            SmemLayoutK, SmemLayoutV, SmemLayoutO>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
// Traits struct for fp8 kernel with in-kernel transpose
//
// Differences from Flash_fwd_kernel_traits: 8-bit element type (asserted),
// fp16 output, warp specialization always on, a full warp group of producer
// threads, and extra layouts describing the in-smem transpose of V (fp8
// WGMMA cannot read operand B MN-major the way the 16-bit path does, so V is
// physically transposed in shared memory — TODO confirm rationale).
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t>
struct Flash_fwd_kernel_traits_fp8 {
    using Element = elem_type;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);
    using ElementAccum = float;          // accumulate in fp32
    using OutputType = cutlass::half_t;  // O written in fp16, not fp8
    using index_t = int64_t;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    // A full warp group produces here (vs a single warp in the 16-bit traits).
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    static_assert(kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;  // fp8 path is always warp-specialized
    static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers");
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
    static constexpr int kStages = kStages_;
    static_assert(kStages > 1);  // transpose pipeline needs at least double buffering — TODO confirm
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    // First GEMM: both operands from smem (no Q-in-regs option here).
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));
    // Second GEMM: default (K-major) operand B — V has been physically
    // transposed in smem, unlike the 16-bit path's MN-major view.
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{}))>(),
        AtomLayoutMNK{}));
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    // 64x64 transpose tile; V uses a fixed K-major 64B-swizzle atom.
    using TransposeShapeAtomV = Shape<_64, _64>;
    using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    // for fp8 in-kernel transpose -- src layout
    using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
    // Per-thread-group shape for the ldmatrix side of the transpose.
    using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
    using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
        shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
    using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
    // For fp8, this is the memory transpose.
    using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutVt =
        decltype(tile_to_shape(SmemLayoutAtomVt{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
    // for fp8 in-kernel transpose -- dst layout
    using SmemLayoutVtTrans =
        decltype(composition(SmemLayoutVt{},
                 make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
    using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
#ifndef NO_FP8_COLUMN_PERMUTE
    // stmatrix-side shape; default keeps the fp8 column permutation.
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
#else
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
#endif
    using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
        shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
    using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
                                  Stride<_4, _32, _1, _0>>;
    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
                                 Stride<_0, _2, Stride<_4, _1>, _8>>;
    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, OutputType>{},
                                                 ThreadLayoutrO{}, ValueLayoutrO{}));
    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SharedStorage = SharedStorageQKVOVt<kStages, Element, Element, OutputType, SmemLayoutQ,
                                              SmemLayoutK, SmemLayoutV, SmemLayoutO>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
  207. ////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared storage for the backward kernel. Specialized on Has_P_smem:
// whether the P (probability) tile gets its own dedicated smem buffer.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV;

// Has_P_smem == true: smem_p and smem_ds are separate, dedicated buffers.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
                             SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;    // Q tiles (multi-stage)
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;  // dO tiles (multi-stage)
        // K/V inputs overlap the dK/dV outputs — presumably K/V are consumed
        // before dK/dV are staged here (TODO confirm write ordering).
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;    // P tile (dedicated in this spec)
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;  // dS tile
    };
    // Synchronization state.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;  // presumably signals K-tile TMA arrival
        cutlass::arch::ClusterTransactionBarrier barrier_V;  // presumably signals V-tile TMA arrival
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Has_P_smem == false: P does not need its own buffer, so smem_p aliases
// smem_ds via a union (saving smem while keeping the member name usable).
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
                             SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;    // Q tiles (multi-stage)
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;  // dO tiles (multi-stage)
        // K/V inputs overlap the dK/dV outputs — same aliasing as the
        // Has_P_smem == true specialization.
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    // Synchronization state.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Warp-specialized variant of the backward shared storage: adds a dQ
// accumulator buffer (fp32) plus LSE and dP-sum scratch (128 floats each).
// Specialized on Has_P_smem, as above.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS;

// Has_P_smem == true: dedicated smem_p and smem_ds buffers.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
                               SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        // K/V inputs overlap the dK/dV outputs (same aliasing as the
        // non-WS storage above).
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;  // fp32 dQ accumulator
        cute::array_aligned<float, 128> smem_lse;    // log-sum-exp scratch — presumably one per query row of the block
        cute::array_aligned<float, 128> smem_dpsum;  // dP row-sum scratch
    };
    // Synchronization state.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Has_P_smem == false: smem_p aliases smem_ds via a union, as in the
// non-warp-specialized storage above.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
                               SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        // K/V inputs overlap the dK/dV outputs.
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;  // fp32 dQ accumulator
        cute::array_aligned<float, 128> smem_lse;    // log-sum-exp scratch
        cute::array_aligned<float, 128> smem_dpsum;  // dP row-sum scratch
    };
    // Synchronization state.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Backward shared storage for the sequence-parallel-over-q variant: here K/V
// are the resident (single-buffered) tiles, while Q/dO are pipelined and
// overlap the dQ output. Specialized on Has_P_smem, as above.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar;

// Has_P_smem == true: dedicated smem_p and smem_ds buffers.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
                                    SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;  // resident K tile
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;  // resident V tile
        // Q/dO inputs overlap the dQ output — presumably consumed before dQ
        // is staged here (TODO confirm write ordering).
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    // Synchronization state.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;   // presumably signals Q-tile TMA arrival
        cutlass::arch::ClusterTransactionBarrier barrier_dO;  // presumably signals dO-tile TMA arrival
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};
// Has_P_smem == false: smem_p aliases smem_ds via a union.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
                                    SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;  // resident K tile
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;  // resident V tile
        // Q/dO inputs overlap the dQ output, as in the true specialization.
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    // Synchronization state.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};
  404. ////////////////////////////////////////////////////////////////////////////////////////////////////
  405. template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
  406. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  407. int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  408. int kClusterN_ = 1, typename elem_type=cutlass::half_t>
  409. struct Flash_bwd_kernel_traits {
  410. using Element = elem_type;
  411. using ElementAccum = float;
  412. using index_t = int64_t;
  413. // The number of threads.
  414. static constexpr int kNWarps = kNWarps_;
  415. static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  416. static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
  417. // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
  418. static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;
  419. static_assert(kNWarps_ == 8 || kNWarps_ == 12);
  420. static constexpr bool Is_WS = kNWarps_ >= 12;
  421. static constexpr int kBlockM = kBlockM_;
  422. static constexpr int kBlockN = kBlockN_;
  423. static constexpr int kHeadDim = kHeadDim_;
  424. static_assert(kHeadDim % 32 == 0);
  425. using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  426. static constexpr int kClusterN = kClusterN_;
  427. using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
  428. static constexpr int kStages = 2;
  429. static constexpr bool SdP_swapAB = SdP_swapAB_;
  430. static constexpr bool dKV_swapAB = dKV_swapAB_;
  431. static constexpr bool dQ_swapAB = dQ_swapAB_;
  432. static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
  433. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  434. using TileShapeAtomSdP = std::conditional_t<
  435. !SdP_swapAB,
  436. Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
  437. Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
  438. >;
  439. using AtomLayoutSdP = std::conditional_t<
  440. !SdP_swapAB,
  441. Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
  442. Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  443. >;
  444. using TiledMmaSdP = decltype(cute::make_tiled_mma(
  445. cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  446. AtomLayoutSdP{}));
  447. using TileShapeAtomdKV = std::conditional_t<
  448. !dKV_swapAB,
  449. Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
  450. Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
  451. >;
  452. using AtomLayoutdKV = std::conditional_t<
  453. !dKV_swapAB,
  454. Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
  455. Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  456. >;
  457. using TiledMmadKV = decltype(cute::make_tiled_mma(
  458. std::conditional_t<
  459. !SdP_swapAB,
  460. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
  461. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
  462. >{},
  463. AtomLayoutdKV{}));
  464. using TileShapeAtomdQ = std::conditional_t<
  465. !dQ_swapAB,
  466. Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
  467. Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
  468. // Shape<Int<kBlockM>, Int<kHeadDim >, Int<kBlockN>>,
  469. // Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
  470. >;
  471. using AtomLayoutdQ = std::conditional_t<
  472. !dQ_swapAB,
  473. Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        // Tail of the AtomLayoutdQ conditional: atom layout used when dQ_swapAB
        // (the M/N roles of the dQ GEMM are exchanged).
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
        // Layout<Shape<Int<1>, Int<1>, _1>>,
        // Layout<Shape<Int<1>, Int<1>, _1>>
    >;
    // Operand majorness for the dQ GEMM; flips when the A/B operands are swapped.
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    // Tiled MMA computing dQ. The RS variant (A operand sourced from registers)
    // is only selected when Mma_dQ_is_RS, which itself requires !dQ_swapAB.
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
        >{},
        AtomLayoutdQ{}));

    // Global-memory copy atoms: Q/dO loads pick a TMA atom from the cluster's
    // N extent (multicast when kClusterN > 1); K/V loads and dK/dV stores use
    // plain (non-multicast) TMA.
    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;

    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    // Elements moved per 128-bit (16 B) vectorized access.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
    // Thread layouts for cooperative gmem copies: row-major assignment,
    // kGmemThreadsPerRow threads covering one row of kBlockKSmem elements.
    using GmemLayoutAtom = Layout<Shape <Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemLayoutAtomdQ = Layout<Shape <Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                    Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtomdQ{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <Int<kNThreadsdQ / 8>, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <Int<kNThreadsdQ / 16>, _16>,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store

    // Shared-memory layouts for the mainloop operands. Q and dO carry an extra
    // kStages mode (multi-buffered); K/V/P/dS are single-buffered tiles.
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ =
        decltype(tile_to_shape(SmemLayoutAtomQ{},
                 make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutdO = SmemLayoutQ;
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
    // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
    //     decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                             make_stride(Int<kBlockN>{}, _1{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    // using SmemLayoutdQacct =
    //     decltype(cute::composition(SmemLayoutdQacc{},
    //              make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    //                          make_stride(Int<kBlockM>{}, _1{}))));
    // dK/dV reuse the K layouts (same (kBlockN, kHeadDim) tile shape).
    using SmemLayoutdK = SmemLayoutK;
    using SmemLayoutdV = SmemLayoutV;
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    // Swizzle for the dQ smem tile is currently hard-coded to <3, 3, 3>
    // (the kSwizzle-based variant is kept commented out below).
    using SmemLayoutAtomdQ = decltype(
        // composition(Swizzle<kSwizzle, 3, 3>{},
        composition(Swizzle<3, 3, 3>{},
                    Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
                           Stride<Int<32>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    // NOTE(review): the atom below is instantiated with the (M, N) tile extents
    // but tiled to the (M, K) shape — confirm this selector choice is intended.
    using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdQacc = SmemLayoutdQ;
    using SmemLayoutdQacct = SmemLayoutdQt;
    // Double-buffered variant of the dQ accumulator tile.
    using SmemLayoutdQacc2 = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
    // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
    // using SmemLayoutdQacct =
    //     decltype(cute::composition(SmemLayoutdQacc{},
    //              make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
    //                          make_stride(Int<kBlockM>{}, _1{}))));
    using RmemTiledCopydQacc = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    // STSM atoms for writing MMA results to smem; the transposed (_T) variant
    // is used when the corresponding GEMM has its operands swapped.
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    // Warp-specialized (Is_WS) kernels additionally stage dQacc in smem.
    using SharedStorage = std::conditional_t<
        !Is_WS,
        SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
                              SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
        SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
                                SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
        // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
    >;
    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
    // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Traits for the backward kernel variant where the Q tile stays resident and
// the kernel iterates over K/V tiles: here K and V carry the kStages smem
// buffering while Q/dO are single-buffered (the reverse of the variant above).
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
         int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
         int kClusterN_ = 1, typename elem_type=cutlass::half_t>
struct Flash_bwd_seqqpar_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using index_t = int64_t;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static_assert(kNWarps_ == 8);
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterN = kClusterN_;
    using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
    static constexpr int kStages = 2;
    // Per-GEMM operand-swap flags (trade A/B roles of the tensor cores).
    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;
    static_assert(!(SdP_swapAB && dKV_swapAB));  // If SdP_swapAB, then we don't swap for dKV
    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS

    // S / dP GEMM: tile shape and atom layout, mirrored when SdP_swapAB.
    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
        Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
        AtomLayoutSdP{}));

    // dK/dV GEMM: when SdP_swapAB the RS variant is used instead.
    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
    >;
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !SdP_swapAB,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
        >{},
        AtomLayoutdKV{}));

    // dQ GEMM: tile shape, atom layout, operand majorness, and the tiled MMA
    // (RS variant when Mma_dQ_is_RS, which requires !dQ_swapAB).
    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
        Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
        Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
    >;
    static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
    static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            !dQ_swapAB,
            std::conditional_t<
                Mma_dQ_is_RS,
                decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
                decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
            >,
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
        >{},
        AtomLayoutdQ{}));

    // Gmem copy atoms: Q/dO use a TMA atom picked from the cluster's N extent
    // (multicast when kClusterN > 1); K/V load and dK/dV store use plain TMA.
    using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
    using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
    using GmemTiledCopydKV = cute::SM90_TMA_STORE;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    static constexpr bool Has_cp_async = true;
#else
    static constexpr bool Has_cp_async = false;
#endif
    // For the dot_do_o preprocessing kernel
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;

    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    // Elements moved per 128-bit (16 B) vectorized access.
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_32, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_16, _16>,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store

    // Smem layouts: Q/dO are single tiles; K/V carry the extra kStages mode
    // because this variant iterates over K/V blocks.
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutdO = SmemLayoutQ;
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
                        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
                        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
    using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt =
        decltype(cute::composition(SmemLayoutQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdOt =
        decltype(cute::composition(SmemLayoutdO{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutKt =
        decltype(cute::composition(SmemLayoutK{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                             make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
    using SmemLayoutPt =
        decltype(cute::composition(SmemLayoutP{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    using SmemLayoutdSt =
        decltype(cute::composition(SmemLayoutdS{},
                 make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    // dK/dV output tiles are single-stage (unlike the staged K/V inputs).
    using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdV = SmemLayoutdK;
    // NOTE(review): these transposed aliases reuse the staged (3-mode)
    // SmemLayoutKt even though SmemLayoutdK above is single-stage — confirm
    // intended usage at the call sites.
    using SmemLayoutdKt = SmemLayoutKt;
    using SmemLayoutdVt = SmemLayoutKt;
    // dQ staged for a TMA store, reusing the K smem atom over the (M, K) tile.
    using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemLayoutdQt =
        decltype(cute::composition(SmemLayoutdQ{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
                             make_stride(Int<kBlockM>{}, _1{}))));
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemLayoutdKVt =
        decltype(cute::composition(SmemLayoutdKV{},
                 make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                             make_stride(Int<kBlockN>{}, _1{}))));
    // Factor 2: one dKV-shaped buffer each for dK and dV.
    static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;
    // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    // STSM atoms for writing MMA results to smem; the transposed (_T) variant
    // is used when the corresponding GEMM has its operands swapped.
    using SmemCopyAtomPdS = Copy_Atom<
        std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
        Element>;
    using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
                              SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;
    // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
    // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
};
////////////////////////////////////////////////////////////////////////////////////////////////////