/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
#pragma once

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

using namespace cute;
// Shared-memory storage for the forward kernel: Q, K, V input tiles plus the
// O output tile.  smem_v and smem_o share storage via an anonymous union --
// the layout implies V and O are never live at the same time (presumably V is
// consumed by the second GEMM before O is staged; confirm against the kernel).
// Member order is layout-bearing: do not reorder fields.
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
    union {
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    // Synchronization state: full-tile barriers for Q and O, plus the
    // multi-stage TMA pipelines for K and V, and a scheduler semaphore.
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        int tile_count_semaphore;
    };
};
// Shared-memory storage for the fp8 forward kernel, which transposes V inside
// the kernel: smem_v holds V as loaded, smem_v_out receives the transposed
// copy.  smem_v_out aliases smem_o via a union -- the layout implies the
// transposed V is dead before O is staged (confirm against the kernel).
// Member order is layout-bearing: do not reorder fields.
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
        };
    };
    // Synchronization state.  pipeline_vt (non-TMA async pipeline) coordinates
    // the in-kernel V transpose stage; the fp8 path also stashes the softmax
    // scale (log2 domain) and the V descale factor in shared memory.
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;
        float softmax_scale_qk_log2;
        float descale_v;
    };
};
// Compile-time configuration ("traits") for the SM90 forward attention kernel
// (fp16 / bf16 path).  Every member is a type alias or constexpr constant
// consumed by the kernel; there is no runtime state here.
//
// Template parameters:
//   kHeadDim_        head dimension; must be a multiple of 32
//   kBlockM_         rows of the Q tile per CTA
//   kBlockN_         rows of the K/V tile per mainloop iteration
//   kNWarps_         warps per CTA (4/8: non-warp-specialized; 12/16: WS)
//   kStages_         software-pipeline depth for the K/V loads
//   Is_Q_in_regs_    keep Q in registers for GEMM0 (incompatible with WS)
//   kClusterM_       CTA cluster size along M
//   elem_type        input element type
//   kUseDocMasking_  flag forwarded to the kernel for document masking
//
// NOTE(review): an earlier comment here referenced a Share_Q_K_smem parameter
// that does not exist in this template; Is_Q_in_regs_ is the only related knob.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::half_t, bool kUseDocMasking_=false>
struct Flash_fwd_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;   // GEMMs accumulate in fp32
    using OutputType = elem_type;
    using index_t = int64_t;
    static constexpr bool kUseDocMasking = kUseDocMasking_;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    // A single warp acts as the load producer in this (non-fp8) configuration.
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    // 12+ warps selects the warp-specialized (producer/consumer) schedule.
    static constexpr bool Is_WS = kNWarps_ >= 12;
    static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    // (M, N, K) tile of GEMM0: (kBlockM, kBlockN, kHeadDim).
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
    static constexpr int kStages = kStages_;
    // One GMMA atom per 64 rows of M; N and K are covered by a single atom.
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    // GEMM0 (S = Q @ K^T, per TileShape_MNK): operand A comes from registers
    // (rs) when Is_Q_in_regs, otherwise from shared memory (ss).
    using TiledMma0 = decltype(cute::make_tiled_mma(
        std::conditional_t<
            Is_Q_in_regs,
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
        >{},
        AtomLayoutMNK{}));
    // GEMM1: tile is (kBlockM, kHeadDim, kBlockN) via select<0,2,1>; operand A
    // is in registers, operand B (V) is read MN-major from shared memory.
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));
    // Smem layouts: Q is single-buffered; K and V carry an extra kStages mode
    // for the software pipeline.
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int<kStages>{})));
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVt =
        decltype(composition(SmemLayoutV{},
                 make_ordered_layout(
                     make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                     Step<_2, _1, _3>{})));
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ,
                                            SmemLayoutK, SmemLayoutV, SmemLayoutO>;
    // TMA-backed mainloop pipeline plus a plain async variant (no transaction
    // barrier) for producers that are not TMA.
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
// Traits struct for the fp8 forward kernel with in-kernel transpose of V.
// fp8 GMMA requires the B operand K-major, so V (loaded MN-major relative to
// the second GEMM) is transposed in shared memory via LDSM/STSM; the layouts
// below describe both the source and destination of that transpose.
// Differences from the fp16 traits: output is fp16, the producer is a full
// warpgroup, and only warp-specialized configurations are supported.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t, bool kUseDocMasking_ = false>
struct Flash_fwd_kernel_traits_fp8 {
    using Element = elem_type;
    static_assert(cutlass::sizeof_bits_v<Element> == 8);
    using ElementAccum = float;
    // fp8 inputs are written out as fp16.
    using OutputType = cutlass::half_t;
    using index_t = int64_t;
    static constexpr bool kUseDocMasking = kUseDocMasking_;
    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    // A full warpgroup produces (loads + transposes V) in the fp8 path.
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    // fp8 path is warp-specialized only.
    static_assert(kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;
    static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers");
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
    static constexpr int kStages = kStages_;
    // The in-kernel transpose pipeline needs at least double buffering.
    static_assert(kStages > 1);
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    // GEMM0 (S = Q @ K^T): both operands from shared memory.
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));
    // GEMM1: (kBlockM, kHeadDim, kBlockN) tile, A in registers; B uses the
    // default K-major (fp8 GMMA has no MN-major B, hence the transpose of V).
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{}))>(),
        AtomLayoutMNK{}));
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    // V is tiled in 64x64 blocks for the transpose; 64B-swizzled atoms.
    using TransposeShapeAtomV = Shape<_64, _64>;
    using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    // for fp8 in-kernel transpose -- src layout
    using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
    // Per-64x64-tile shape seen by the LDSM loads of the transpose.
    using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
    using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
        shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
    using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
    // For fp8, this is the memory transpose.
    using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutVt =
        decltype(tile_to_shape(SmemLayoutAtomVt{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
    // for fp8 in-kernel transpose -- dst layout
    using SmemLayoutVtTrans =
        decltype(composition(SmemLayoutVt{},
                 make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
    using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
// STSM store shape; the default variant applies the fp8 column permutation,
// NO_FP8_COLUMN_PERMUTE disables it (then undone nowhere -- debug toggle).
#ifndef NO_FP8_COLUMN_PERMUTE
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
#else
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
#endif
    using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
        shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
    using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
                                  Stride<_4, _32, _1, _0>>;
    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
                                 Stride<_0, _2, Stride<_4, _1>, _8>>;
    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, OutputType>{},
                                                 ThreadLayoutrO{}, ValueLayoutrO{}));
    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
    using SharedStorage = SharedStorageQKVOVt<kStages, Element, Element, OutputType, SmemLayoutQ,
                                              SmemLayoutK, SmemLayoutV, SmemLayoutO>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    // Non-TMA pipeline used by the shared-memory V transpose stage.
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory storage for the backward (dK/dV) kernel, dispatched on
// Has_P_smem: the <true> specialization gives P a dedicated buffer, the
// <false> specialization aliases P with dS in a union.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV;
// Has_P_smem == true: P and dS each get dedicated shared memory.  K/V alias
// dK/dV via a union -- the layout implies the K/V inputs are dead by the time
// dK/dV are staged for output (confirm against the kernel).  Member order is
// layout-bearing: do not reorder fields.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    // Synchronization state: one-shot barriers for the K/V loads and
    // multi-stage TMA pipelines for the Q / dO streams.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Has_P_smem == false: identical to the <true> specialization except that P
// and dS share one buffer via a union (P is kept addressable even when not
// materialized separately).  Member order is layout-bearing: do not reorder.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    // Synchronization state (same as the <true> specialization).
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Warp-specialized variant of SharedStorageQKVdOdKV: adds a dQ accumulator
// buffer plus LSE / dP-sum scratch.  Dispatched on Has_P_smem like the
// non-WS version.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS;
// WS, Has_P_smem == true: dedicated P and dS buffers; K/V alias dK/dV via a
// union.  Adds smem_dqacc (fp32 dQ accumulator) and 128-float scratch arrays
// for LSE and dP-sum.  Member order is layout-bearing: do not reorder fields.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
        cute::array_aligned<float, 128> smem_lse;
        cute::array_aligned<float, 128> smem_dpsum;
    };
    // Synchronization state: K/V load barriers plus Q / dO TMA pipelines.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// WS, Has_P_smem == false: same as the <true> specialization but P aliases dS
// in a union (P stays addressable even when not separately materialized).
// Member order is layout-bearing: do not reorder fields.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
        cute::array_aligned<float, 128> smem_lse;
        cute::array_aligned<float, 128> smem_dpsum;
    };
    // Synchronization state (same as the <true> specialization).
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Shared-memory storage for the seq-q-parallel backward kernel: K/V are the
// persistent tiles and Q / dO are streamed, with dQ as the output (mirror of
// SharedStorageQKVdOdKV).  Dispatched on Has_P_smem as before.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar;
// Seq-q-parallel, Has_P_smem == true: K/V persistent; the streamed Q / dO
// buffers alias the dQ output via a union.  Dedicated P and dS buffers.
// Member order is layout-bearing: do not reorder fields.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    // Synchronization state: one-shot barriers for Q / dO, pipelines for K / V.
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};
// Seq-q-parallel, Has_P_smem == false: same as the <true> specialization but
// P aliases dS in a union.  Member order is layout-bearing: do not reorder.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    // Synchronization state (same as the <true> specialization).
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
  409. template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
  410. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  411. int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  412. int kClusterN_ = 1, typename elem_type=cutlass::half_t>
  413. struct Flash_bwd_kernel_traits {
  414. using Element = elem_type;
  415. using ElementAccum = float;
  416. using index_t = int64_t;
  417. // The number of threads.
  418. static constexpr int kNWarps = kNWarps_;
  419. static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  420. static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
  421. // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
  422. static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;
  423. static_assert(kNWarps_ == 8 || kNWarps_ == 12);
  424. static constexpr bool Is_WS = kNWarps_ >= 12;
  425. static constexpr int kBlockM = kBlockM_;
  426. static constexpr int kBlockN = kBlockN_;
  427. static constexpr int kHeadDim = kHeadDim_;
  428. static_assert(kHeadDim % 32 == 0);
  429. using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  430. static constexpr int kClusterN = kClusterN_;
  431. using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
  432. static constexpr int kStages = 2;
  433. static constexpr bool SdP_swapAB = SdP_swapAB_;
  434. static constexpr bool dKV_swapAB = dKV_swapAB_;
  435. static constexpr bool dQ_swapAB = dQ_swapAB_;
  436. static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
  437. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  438. using TileShapeAtomSdP = std::conditional_t<
  439. !SdP_swapAB,
  440. Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
  441. Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
  442. >;
  443. using AtomLayoutSdP = std::conditional_t<
  444. !SdP_swapAB,
  445. Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
  446. Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  447. >;
  448. using TiledMmaSdP = decltype(cute::make_tiled_mma(
  449. cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  450. AtomLayoutSdP{}));
  451. using TileShapeAtomdKV = std::conditional_t<
  452. !dKV_swapAB,
  453. Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
  454. Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
  455. >;
  456. using AtomLayoutdKV = std::conditional_t<
  457. !dKV_swapAB,
  458. Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
  459. Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  460. >;
  461. using TiledMmadKV = decltype(cute::make_tiled_mma(
  462. std::conditional_t<
  463. !SdP_swapAB,
  464. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
  465. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
  466. >{},
  467. AtomLayoutdKV{}));
  468. using TileShapeAtomdQ = std::conditional_t<
  469. !dQ_swapAB,
  470. Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
  471. Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
  472. // Shape<Int<kBlockM>, Int<kHeadDim >, Int<kBlockN>>,
  473. // Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
  474. >;
  475. using AtomLayoutdQ = std::conditional_t<
  476. !dQ_swapAB,
  477. Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
  478. Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  479. // Layout<Shape<Int<1>, Int<1>, _1>>,
  480. // Layout<Shape<Int<1>, Int<1>, _1>>
  481. >;
  482. static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
  483. static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
  484. using TiledMmadQ = decltype(cute::make_tiled_mma(
  485. std::conditional_t<
  486. !dQ_swapAB,
  487. std::conditional_t<
  488. Mma_dQ_is_RS,
  489. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
  490. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
  491. >,
  492. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
  493. >{},
  494. AtomLayoutdQ{}));
  495. using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  496. using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  497. using GmemTiledCopydKV = cute::SM90_TMA_STORE;
  498. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  499. static constexpr bool Has_cp_async = true;
  500. #else
  501. static constexpr bool Has_cp_async = false;
  502. #endif
  503. // For the dot_do_o preprocessing kernel
  504. using Gmem_copy_struct = std::conditional_t<
  505. Has_cp_async,
  506. SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
  507. DefaultCopy
  508. >;
  509. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
  510. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  511. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
  512. // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
  513. // to affect speed in practice.
  514. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
  515. static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
  516. using GmemLayoutAtom = Layout<Shape <Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  517. Stride<Int<kGmemThreadsPerRow>, _1>>;
  518. using GmemLayoutAtomdQ = Layout<Shape <Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  519. Stride<Int<kGmemThreadsPerRow>, _1>>;
  520. using GmemTiledCopydO = decltype(
  521. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  522. GmemLayoutAtom{},
  523. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  524. using GmemTiledCopydQ = decltype(
  525. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  526. GmemLayoutAtomdQ{},
  527. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  528. using GmemLayoutAtomdQaccum = std::conditional_t<
  529. kBlockKSmem == 32,
  530. Layout<Shape <Int<kNThreadsdQ / 8>, _8>, // Thread layout, 8 threads per row
  531. Stride< _8, _1>>,
  532. Layout<Shape <Int<kNThreadsdQ / 16>, _16>, // Thread layout, 16 threads per row
  533. Stride< _16, _1>>
  534. >;
  535. using GmemTiledCopydQaccum = decltype(
  536. make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
  537. GmemLayoutAtomdQaccum{},
  538. Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
  539. using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  540. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  541. using SmemLayoutQ =
  542. decltype(tile_to_shape(SmemLayoutAtomQ{},
  543. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  544. using SmemLayoutdO = SmemLayoutQ;
  545. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  546. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  547. using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
  548. using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  549. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  550. using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
  551. using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  552. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  553. using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
  554. using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  555. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  556. using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
  557. // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
  558. // decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  559. // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
  560. // Note this is the transpose in terms of the view, not in terms of memory.
  561. using SmemLayoutQt =
  562. decltype(cute::composition(SmemLayoutQ{},
  563. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  564. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  565. using SmemLayoutdOt =
  566. decltype(cute::composition(SmemLayoutdO{},
  567. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  568. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  569. using SmemLayoutKt =
  570. decltype(cute::composition(SmemLayoutK{},
  571. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  572. make_stride(Int<kBlockN>{}, _1{}))));
  573. using SmemLayoutPt =
  574. decltype(cute::composition(SmemLayoutP{},
  575. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  576. make_stride(Int<kBlockM>{}, _1{}))));
  577. using SmemLayoutdSt =
  578. decltype(cute::composition(SmemLayoutdS{},
  579. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  580. make_stride(Int<kBlockM>{}, _1{}))));
  581. // using SmemLayoutdQacct =
  582. // decltype(cute::composition(SmemLayoutdQacc{},
  583. // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  584. // make_stride(Int<kBlockM>{}, _1{}))));
  585. using SmemLayoutdK = SmemLayoutK;
  586. using SmemLayoutdV = SmemLayoutV;
  587. using SmemLayoutdKt = SmemLayoutKt;
  588. using SmemLayoutdVt = SmemLayoutKt;
  589. static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
  590. using SmemLayoutAtomdQ = decltype(
  591. // composition(Swizzle<kSwizzle, 3, 3>{},
  592. composition(Swizzle<3, 3, 3>{},
  593. Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
  594. Stride<Int<32>, _1>>{}));
  595. using SmemLayoutdQ = decltype(tile_to_shape(
  596. SmemLayoutAtomdQ{},
  597. make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
  598. using SmemLayoutdQt =
  599. decltype(cute::composition(SmemLayoutdQ{},
  600. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  601. make_stride(Int<kBlockM>{}, _1{}))));
  602. static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
  603. using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
  604. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  605. using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
  606. using SmemLayoutdQacc = SmemLayoutdQ;
  607. using SmemLayoutdQacct = SmemLayoutdQt;
  608. using SmemLayoutdQacc2 = decltype(tile_to_shape(
  609. SmemLayoutAtomdQ{},
  610. make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
  611. // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
  612. // using SmemLayoutdQacct =
  613. // decltype(cute::composition(SmemLayoutdQacc{},
  614. // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  615. // make_stride(Int<kBlockM>{}, _1{}))));
  616. using RmemTiledCopydQacc = decltype(
  617. make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
  618. GmemLayoutAtomdQaccum{},
  619. Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
  620. // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
  621. using SmemCopyAtomPdS = Copy_Atom<
  622. std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  623. Element>;
  624. using SmemCopyAtomdKV = Copy_Atom<
  625. std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  626. Element>;
  627. using SmemCopyAtomdQ = Copy_Atom<
  628. std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  629. Element>;
  630. using SharedStorage = std::conditional_t<
  631. !Is_WS,
  632. SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
  633. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
  634. SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
  635. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
  636. // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
  637. >;
  638. // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
  639. // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
  640. using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  641. };
  642. ////////////////////////////////////////////////////////////////////////////////////////////////////
  643. template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
  644. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  645. int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  646. int kClusterN_ = 1, typename elem_type=cutlass::half_t>
  647. struct Flash_bwd_seqqpar_kernel_traits {
  648. using Element = elem_type;
  649. using ElementAccum = float;
  650. using index_t = int64_t;
  651. // The number of threads.
  652. static constexpr int kNWarps = kNWarps_;
  653. static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  654. static_assert(kNWarps_ == 8);
  655. static constexpr int kBlockM = kBlockM_;
  656. static constexpr int kBlockN = kBlockN_;
  657. static constexpr int kHeadDim = kHeadDim_;
  658. static_assert(kHeadDim % 32 == 0);
  659. using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  660. static constexpr int kClusterN = kClusterN_;
  661. using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
  662. static constexpr int kStages = 2;
  663. static constexpr bool SdP_swapAB = SdP_swapAB_;
  664. static constexpr bool dKV_swapAB = dKV_swapAB_;
  665. static constexpr bool dQ_swapAB = dQ_swapAB_;
  666. static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
  667. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  668. using TileShapeAtomSdP = std::conditional_t<
  669. !SdP_swapAB,
  670. Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
  671. Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
  672. >;
  673. using AtomLayoutSdP = std::conditional_t<
  674. !SdP_swapAB,
  675. Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
  676. Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  677. >;
  678. using TiledMmaSdP = decltype(cute::make_tiled_mma(
  679. cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  680. AtomLayoutSdP{}));
  681. using TileShapeAtomdKV = std::conditional_t<
  682. !dKV_swapAB,
  683. Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
  684. Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
  685. >;
  686. using AtomLayoutdKV = std::conditional_t<
  687. !dKV_swapAB,
  688. Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
  689. Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  690. >;
  691. using TiledMmadKV = decltype(cute::make_tiled_mma(
  692. std::conditional_t<
  693. !SdP_swapAB,
  694. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
  695. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
  696. >{},
  697. AtomLayoutdKV{}));
  698. using TileShapeAtomdQ = std::conditional_t<
  699. !dQ_swapAB,
  700. Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
  701. Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
  702. >;
  703. using AtomLayoutdQ = std::conditional_t<
  704. !dQ_swapAB,
  705. Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
  706. Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  707. >;
  708. static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
  709. static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
  710. using TiledMmadQ = decltype(cute::make_tiled_mma(
  711. std::conditional_t<
  712. !dQ_swapAB,
  713. std::conditional_t<
  714. Mma_dQ_is_RS,
  715. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
  716. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
  717. >,
  718. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
  719. >{},
  720. AtomLayoutdQ{}));
  721. using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  722. using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  723. using GmemTiledCopydKV = cute::SM90_TMA_STORE;
  724. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  725. static constexpr bool Has_cp_async = true;
  726. #else
  727. static constexpr bool Has_cp_async = false;
  728. #endif
  729. // For the dot_do_o preprocessing kernel
  730. using Gmem_copy_struct = std::conditional_t<
  731. Has_cp_async,
  732. SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
  733. DefaultCopy
  734. >;
  735. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
  736. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  737. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
  738. // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
  739. // to affect speed in practice.
  740. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
  741. static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
  742. using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  743. Stride<Int<kGmemThreadsPerRow>, _1>>;
  744. using GmemTiledCopydO = decltype(
  745. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  746. GmemLayoutAtom{},
  747. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  748. using GmemTiledCopydQ = decltype(
  749. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  750. GmemLayoutAtom{},
  751. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  752. using GmemLayoutAtomdQaccum = std::conditional_t<
  753. kBlockKSmem == 32,
  754. Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
  755. Stride< _8, _1>>,
  756. Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
  757. Stride< _16, _1>>
  758. >;
  759. using GmemTiledCopydQaccum = decltype(
  760. make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
  761. GmemLayoutAtomdQaccum{},
  762. Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
  763. using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  764. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  765. using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
  766. using SmemLayoutdO = SmemLayoutQ;
  767. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  768. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  769. using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
  770. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  771. using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  772. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  773. using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
  774. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  775. using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  776. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  777. using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
  778. using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  779. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  780. using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
  781. // Note this is the transpose in terms of the view, not in terms of memory.
  782. using SmemLayoutQt =
  783. decltype(cute::composition(SmemLayoutQ{},
  784. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  785. make_stride(Int<kBlockM>{}, _1{}))));
  786. using SmemLayoutdOt =
  787. decltype(cute::composition(SmemLayoutdO{},
  788. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  789. make_stride(Int<kBlockM>{}, _1{}))));
  790. using SmemLayoutKt =
  791. decltype(cute::composition(SmemLayoutK{},
  792. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
  793. make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
  794. using SmemLayoutPt =
  795. decltype(cute::composition(SmemLayoutP{},
  796. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  797. make_stride(Int<kBlockM>{}, _1{}))));
  798. using SmemLayoutdSt =
  799. decltype(cute::composition(SmemLayoutdS{},
  800. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  801. make_stride(Int<kBlockM>{}, _1{}))));
  802. using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
  803. using SmemLayoutdV = SmemLayoutdK;
  804. using SmemLayoutdKt = SmemLayoutKt;
  805. using SmemLayoutdVt = SmemLayoutKt;
  806. using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));
  807. static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
  808. using SmemLayoutAtomdQ = decltype(
  809. composition(Swizzle<kSwizzle, 3, 3>{},
  810. Layout<Shape<_8, Int<kBlockKSmem>>,
  811. Stride<Int<kBlockKSmem>, _1>>{}));
  812. using SmemLayoutdQ = decltype(tile_to_shape(
  813. SmemLayoutAtomdQ{},
  814. make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
  815. using SmemLayoutdQt =
  816. decltype(cute::composition(SmemLayoutdQ{},
  817. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  818. make_stride(Int<kBlockM>{}, _1{}))));
  819. static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
  820. using SmemLayoutAtomdKV = decltype(
  821. composition(Swizzle<kSwizzle, 3, 3>{},
  822. Layout<Shape<_8, Int<kBlockKSmem>>,
  823. Stride<Int<kBlockKSmem>, _1>>{}));
  824. using SmemLayoutdKV = decltype(tile_to_shape(
  825. SmemLayoutAtomdKV{},
  826. make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
  827. using SmemLayoutdKVt =
  828. decltype(cute::composition(SmemLayoutdKV{},
  829. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  830. make_stride(Int<kBlockN>{}, _1{}))));
  831. static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;
  832. // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
  833. using SmemCopyAtomPdS = Copy_Atom<
  834. std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  835. Element>;
  836. using SmemCopyAtomdKV = Copy_Atom<
  837. std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  838. Element>;
  839. using SmemCopyAtomdQ = Copy_Atom<
  840. std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  841. Element>;
  842. using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
  843. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;
  844. // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
  845. // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
  846. using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  847. };
  848. ////////////////////////////////////////////////////////////////////////////////////////////////////