// kernel_traits.h
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

using namespace cute;
// Shared-memory storage for the forward kernel: tiles for Q, K, V and the
// output O, plus the barriers / pipeline state used to coordinate the TMA
// loads.  smem_v and smem_o are overlapped in a union so O reuses V's
// storage (sound only if V is fully consumed before O is staged -- the
// kernel must guarantee that ordering).
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVO {
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;  // Q tile
    cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;  // K tile(s)
    union {
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;   // V tile(s)
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;  // O tile, aliases smem_v
    };
    struct {
        // Transaction barrier (tracks byte count), by name used for the Q load.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        // Plain cluster barrier, by name used for the O epilogue.
        cutlass::arch::ClusterBarrier barrier_O;
        // Producer/consumer pipeline state for the staged K and V TMA loads.
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        // Counter shared across the block; presumably for persistent-kernel
        // tile scheduling -- TODO confirm against the scheduler code.
        int tile_count_semaphore;
    };
};
// Shared-memory storage for the fp8 forward kernel with in-kernel V
// transpose: V is loaded into smem_v and a transposed copy is written to
// smem_v_out.  smem_v_out shares storage with the O tile (sound only if the
// transposed V is consumed before O is staged).
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
struct SharedStorageQKVOVt {
    struct {
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;  // Q tile
        cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;  // K tile(s)
        cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;  // V tile(s), as loaded
        union {
            // Transposed V (same cosize as smem_v) aliased with the O tile.
            cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
            cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
        };
    };
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;  // TMA transaction barrier, by name for Q
        cutlass::arch::ClusterBarrier barrier_O;             // by name for the O epilogue
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
        // Non-TMA pipeline: the transposed V is produced by regular smem
        // writes, not by TMA, hence PipelineAsync rather than PipelineTmaAsync.
        typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
        int tile_count_semaphore;      // block-shared counter (scheduling -- TODO confirm)
        // Scalars staged in smem for the fp8 path; by name: softmax scale
        // premultiplied by log2(e), and the V dequantization scale -- TODO confirm.
        float softmax_scale_qk_log2;
        float descale_v;
    };
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
// Kernel traits for the SM90 forward kernel with 16-bit element types.
// Bundles tile/cluster shapes, tiled WGMMA types, swizzled smem layouts and
// pipeline types so the kernel body refers only to this struct.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::half_t>
struct Flash_fwd_kernel_traits {
    using Element = elem_type;     // input element type (half_t by default)
    using ElementAccum = float;    // MMA accumulator type
    using OutputType = elem_type;  // output written in the input precision
    using index_t = int64_t;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    // A single warp acts as producer (issues the loads).
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16);
    // 12+ warps => warp-specialized schedule.
    static constexpr bool Is_WS = kNWarps_ >= 12;
    static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers");

    // Tile sizes: M x N x K = (query rows) x (key/value rows) x (head dim).
    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
    static constexpr int kStages = kStages_;

    // One 64-row WGMMA atom per 64 rows of the M tile.
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    // First GEMM (by convention S = Q K^T): smem-smem (SS) operands, or
    // register-smem (RS) when Q is held in registers.
    using TiledMma0 = decltype(cute::make_tiled_mma(
        std::conditional_t<
            Is_Q_in_regs,
            decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShape_MNK>()),
            decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())
        >{},
        AtomLayoutMNK{}));
    // Second GEMM (by convention O = P V): A operand in registers (RS), tile
    // selected as (M, K, N), B operand read MN-major.
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));

    // Swizzled smem layout atoms picked by CUTLASS for each operand, then
    // tiled out to the full (staged) tile shapes.
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    // K and V carry a kStages-deep pipeline dimension.
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{}), Int<kStages>{})));
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVt =
        decltype(composition(SmemLayoutV{},
                 make_ordered_layout(
                     make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                     Step<_2, _1, _3>{})));
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
    // ldmatrix copy atom for loading Q from smem into registers.
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;

    using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ,
                                            SmemLayoutK, SmemLayoutV, SmemLayoutO>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
// Traits struct for fp8 kernel with in-kernel transpose
// Like Flash_fwd_kernel_traits, but for 8-bit element types.  fp8 WGMMA
// requires the B operand K-major, so V is transposed inside the kernel
// (smem -> smem via LDSM/STSM-shaped copies); outputs are written as fp16.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
         int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t>
struct Flash_fwd_kernel_traits_fp8 {
    using Element = elem_type;  // fp8 input element type (e4m3 by default)
    static_assert(cutlass::sizeof_bits_v<Element> == 8);
    using ElementAccum = float;          // MMA accumulator type
    using OutputType = cutlass::half_t;  // O is written in fp16, not fp8
    using index_t = int64_t;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    // A full warp group of producer threads (vs. one warp in the 16-bit path).
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_;
    // fp8 path is always warp-specialized.
    static_assert(kNWarps_ == 12 || kNWarps_ == 16);
    static constexpr bool Is_WS = true;
    static_assert(!Is_Q_in_regs, "Warp-specialization does not support Q in registers");

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
    static constexpr int kClusterM = kClusterM_;
    using ClusterShape_MNK = Shape<Int<kClusterM>, _1, _1>;
    static constexpr int kStages = kStages_;
    static_assert(kStages > 1);

    // One 64-row WGMMA atom per 64 rows of the M tile.
    using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
    // First GEMM: both operands from smem.
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
        AtomLayoutMNK{}));
    // Second GEMM: A in registers, tile selected as (M, K, N); defaults give
    // a K-major B operand, which is why V needs the memory transpose below.
    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{}))>(),
        AtomLayoutMNK{}));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    // The transpose operates on 64x64 sub-tiles with a 64B-swizzled atom.
    using TransposeShapeAtomV = Shape<_64, _64>;
    using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutV =
        decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
    // for fp8 in-kernel transpose -- src layout
    using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
    // Per-thread LDSM tiling of each 64x64 sub-tile.
    using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
    using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{},
        shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{})));
    using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));

    // For fp8, this is the memory transpose.
    using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
    using SmemLayoutVt =
        decltype(tile_to_shape(SmemLayoutAtomVt{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
    // for fp8 in-kernel transpose -- dst layout
    using SmemLayoutVtTrans =
        decltype(composition(SmemLayoutVt{},
                 make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
    using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
#ifndef NO_FP8_COLUMN_PERMUTE
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
#else
    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
#endif
    using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{},
        shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{})));
    using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, OutputType,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
                                  Stride<_4, _32, _1, _0>>;
    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
                                 Stride<_0, _2, Stride<_4, _1>, _8>>;
    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, OutputType>{},
                                                 ThreadLayoutrO{}, ValueLayoutrO{}));
    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));
    // ldmatrix copy atom for loading Q from smem into registers.
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;

    using SharedStorage = SharedStorageQKVOVt<kStages, Element, Element, OutputType, SmemLayoutQ,
                                              SmemLayoutK, SmemLayoutV, SmemLayoutO>;
    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
    using MainloopPipelineNoTMA = typename cutlass::PipelineAsync<kStages>;
    using PipelineState = typename cutlass::PipelineState<kStages>;
    // using BarrierType = typename MainloopPipeline::ProducerBarrierType;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Shared-memory storage for the backward kernel.  Has_P_smem selects whether
// the P tile (by name, the softmax probabilities) gets a dedicated buffer or
// shares one with dS (see the false specialization below).
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV;

// Has_P_smem == true: P and dS each get their own buffer.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            // The K/V input buffers are reused for the dK/dV outputs (sound
            // only once K/V are no longer read).
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;  // TMA transaction barrier, by name for K
        cutlass::arch::ClusterTransactionBarrier barrier_V;  // TMA transaction barrier, by name for V
        // Pipeline state for the staged Q and dO loads.
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Has_P_smem == false: identical to the true specialization except that P and
// dS share one buffer via a union (P is kept addressable even when unused).
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKV<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            // K/V input buffers reused for the dK/dV outputs.
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Warp-specialized variant of SharedStorageQKVdOdKV: additionally holds a
// dQ accumulator buffer plus LSE / dP-sum rows staged in smem.
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS;

// Has_P_smem == true: P and dS each get their own buffer.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            // K/V input buffers reused for the dK/dV outputs.
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        // fp32 dQ accumulator staged in smem.
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
        // 128 floats each; by name the log-sum-exp and dP row sums -- sized
        // for up to 128 rows per tile (TODO confirm against kernel usage).
        cute::array_aligned<float, 128> smem_lse;
        cute::array_aligned<float, 128> smem_dpsum;
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Has_P_smem == false: as above, but P shares a buffer with dS.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS, class SmemLayoutdQacc,
          class SmemLayoutdK, class SmemLayoutdV>
struct SharedStorageQKVdOdKVWS<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
        union {
            // K/V input buffers reused for the dK/dV outputs.
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdK>> smem_dk;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdV>> smem_dv;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
        // fp32 dQ accumulator plus 128-float LSE / dP-sum staging rows.
        cute::array_aligned<float, cute::cosize_v<SmemLayoutdQacc>> smem_dqacc;
        cute::array_aligned<float, 128> smem_lse;
        cute::array_aligned<float, 128> smem_dpsum;
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_K;
        cutlass::arch::ClusterTransactionBarrier barrier_V;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_q;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_do;
    };
};
// Backward shared storage for the seq-q-parallel schedule: here K/V are the
// resident tiles and Q/dO are streamed, so the union aliasing is flipped --
// Q/dO share storage with the dQ output (vs. K/V with dK/dV above).
template <bool Has_P_smem, int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar;

// Has_P_smem == true: P and dS each get their own buffer.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<true, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            // Q/dO input buffers reused for the dQ output.
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;   // TMA transaction barrier, by name for Q
        cutlass::arch::ClusterTransactionBarrier barrier_dO;  // TMA transaction barrier, by name for dO
        // Pipeline state for the staged K and V loads.
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};
// Has_P_smem == false: as above, but P shares a buffer with dS.
template <int kStages, class Element, class OutputType, class SmemLayoutQ, class SmemLayoutdO,
          class SmemLayoutK, class SmemLayoutV, class SmemLayoutP, class SmemLayoutdS,
          class SmemLayoutdQ>
struct SharedStorageQKVdOdKVSeqqPar<false, kStages, Element, OutputType, SmemLayoutQ, SmemLayoutdO,
        SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQ> {
    struct {
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
        cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
        union {
            // Q/dO input buffers reused for the dQ output.
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdO>> smem_do;
            };
            struct {
                cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
            };
        };
        union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used.
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutP>> smem_p;
            cute::array_aligned<Element, cute::cosize_v<SmemLayoutdS>> smem_ds;
        };
    };
    struct {
        cute::uint64_t tma_load_mbar[8]; // 8 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterTransactionBarrier barrier_dO;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
    };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
  407. template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
  408. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  409. int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  410. int kClusterN_ = 1, typename elem_type=cutlass::half_t>
  411. struct Flash_bwd_kernel_traits {
  412. using Element = elem_type;
  413. using ElementAccum = float;
  414. using index_t = int64_t;
  415. // The number of threads.
  416. static constexpr int kNWarps = kNWarps_;
  417. static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  418. static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp;
  419. // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup;
  420. static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup;
  421. static_assert(kNWarps_ == 8 || kNWarps_ == 12);
  422. static constexpr bool Is_WS = kNWarps_ >= 12;
  423. static constexpr int kBlockM = kBlockM_;
  424. static constexpr int kBlockN = kBlockN_;
  425. static constexpr int kHeadDim = kHeadDim_;
  426. static_assert(kHeadDim % 32 == 0);
  427. using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  428. static constexpr int kClusterN = kClusterN_;
  429. using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
  430. static constexpr int kStages = 2;
  431. static constexpr bool SdP_swapAB = SdP_swapAB_;
  432. static constexpr bool dKV_swapAB = dKV_swapAB_;
  433. static constexpr bool dQ_swapAB = dQ_swapAB_;
  434. static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
  435. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  436. using TileShapeAtomSdP = std::conditional_t<
  437. !SdP_swapAB,
  438. Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
  439. Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
  440. >;
  441. using AtomLayoutSdP = std::conditional_t<
  442. !SdP_swapAB,
  443. Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
  444. Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  445. >;
  446. using TiledMmaSdP = decltype(cute::make_tiled_mma(
  447. cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  448. AtomLayoutSdP{}));
  449. using TileShapeAtomdKV = std::conditional_t<
  450. !dKV_swapAB,
  451. Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
  452. Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
  453. >;
  454. using AtomLayoutdKV = std::conditional_t<
  455. !dKV_swapAB,
  456. Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
  457. Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  458. >;
  459. using TiledMmadKV = decltype(cute::make_tiled_mma(
  460. std::conditional_t<
  461. !SdP_swapAB,
  462. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
  463. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
  464. >{},
  465. AtomLayoutdKV{}));
  466. using TileShapeAtomdQ = std::conditional_t<
  467. !dQ_swapAB,
  468. Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
  469. Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
  470. // Shape<Int<kBlockM>, Int<kHeadDim >, Int<kBlockN>>,
  471. // Shape<Int<kHeadDim>, Int<kBlockM>, Int<kBlockN>>
  472. >;
  473. using AtomLayoutdQ = std::conditional_t<
  474. !dQ_swapAB,
  475. Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
  476. Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  477. // Layout<Shape<Int<1>, Int<1>, _1>>,
  478. // Layout<Shape<Int<1>, Int<1>, _1>>
  479. >;
  480. static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
  481. static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
  482. using TiledMmadQ = decltype(cute::make_tiled_mma(
  483. std::conditional_t<
  484. !dQ_swapAB,
  485. std::conditional_t<
  486. Mma_dQ_is_RS,
  487. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
  488. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
  489. >,
  490. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
  491. >{},
  492. AtomLayoutdQ{}));
  493. using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  494. using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  495. using GmemTiledCopydKV = cute::SM90_TMA_STORE;
  496. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  497. static constexpr bool Has_cp_async = true;
  498. #else
  499. static constexpr bool Has_cp_async = false;
  500. #endif
  501. // For the dot_do_o preprocessing kernel
  502. using Gmem_copy_struct = std::conditional_t<
  503. Has_cp_async,
  504. SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
  505. DefaultCopy
  506. >;
  507. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
  508. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  509. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
  510. // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
  511. // to affect speed in practice.
  512. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
  513. static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow");
  514. using GmemLayoutAtom = Layout<Shape <Int<kNThreadsNonWS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  515. Stride<Int<kGmemThreadsPerRow>, _1>>;
  516. using GmemLayoutAtomdQ = Layout<Shape <Int<kNThreadsdQ / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  517. Stride<Int<kGmemThreadsPerRow>, _1>>;
  518. using GmemTiledCopydO = decltype(
  519. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  520. GmemLayoutAtom{},
  521. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  522. using GmemTiledCopydQ = decltype(
  523. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  524. GmemLayoutAtomdQ{},
  525. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  526. using GmemLayoutAtomdQaccum = std::conditional_t<
  527. kBlockKSmem == 32,
  528. Layout<Shape <Int<kNThreadsdQ / 8>, _8>, // Thread layout, 8 threads per row
  529. Stride< _8, _1>>,
  530. Layout<Shape <Int<kNThreadsdQ / 16>, _16>, // Thread layout, 16 threads per row
  531. Stride< _16, _1>>
  532. >;
  533. using GmemTiledCopydQaccum = decltype(
  534. make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
  535. GmemLayoutAtomdQaccum{},
  536. Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
  537. using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  538. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  539. using SmemLayoutQ =
  540. decltype(tile_to_shape(SmemLayoutAtomQ{},
  541. make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  542. using SmemLayoutdO = SmemLayoutQ;
  543. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  544. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  545. using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
  546. using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  547. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  548. using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));
  549. using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  550. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  551. using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
  552. using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  553. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  554. using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
  555. // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
  556. // decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  557. // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
  558. // Note this is the transpose in terms of the view, not in terms of memory.
  559. using SmemLayoutQt =
  560. decltype(cute::composition(SmemLayoutQ{},
  561. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  562. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  563. using SmemLayoutdOt =
  564. decltype(cute::composition(SmemLayoutdO{},
  565. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int<kStages>{}),
  566. make_stride(Int<kBlockM>{}, _1{}, Int<kBlockM * kHeadDim>{}))));
  567. using SmemLayoutKt =
  568. decltype(cute::composition(SmemLayoutK{},
  569. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  570. make_stride(Int<kBlockN>{}, _1{}))));
  571. using SmemLayoutPt =
  572. decltype(cute::composition(SmemLayoutP{},
  573. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  574. make_stride(Int<kBlockM>{}, _1{}))));
  575. using SmemLayoutdSt =
  576. decltype(cute::composition(SmemLayoutdS{},
  577. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  578. make_stride(Int<kBlockM>{}, _1{}))));
  579. // using SmemLayoutdQacct =
  580. // decltype(cute::composition(SmemLayoutdQacc{},
  581. // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  582. // make_stride(Int<kBlockM>{}, _1{}))));
  583. using SmemLayoutdK = SmemLayoutK;
  584. using SmemLayoutdV = SmemLayoutV;
  585. using SmemLayoutdKt = SmemLayoutKt;
  586. using SmemLayoutdVt = SmemLayoutKt;
  587. static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
  588. using SmemLayoutAtomdQ = decltype(
  589. // composition(Swizzle<kSwizzle, 3, 3>{},
  590. composition(Swizzle<3, 3, 3>{},
  591. Layout<Shape<Int<kNThreadsdQ / 32>, Int<32>>,
  592. Stride<Int<32>, _1>>{}));
  593. using SmemLayoutdQ = decltype(tile_to_shape(
  594. SmemLayoutAtomdQ{},
  595. make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
  596. using SmemLayoutdQt =
  597. decltype(cute::composition(SmemLayoutdQ{},
  598. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  599. make_stride(Int<kBlockM>{}, _1{}))));
  600. static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
  601. using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementAccum,
  602. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  603. using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{})));
  604. using SmemLayoutdQacc = SmemLayoutdQ;
  605. using SmemLayoutdQacct = SmemLayoutdQt;
  606. using SmemLayoutdQacc2 = decltype(tile_to_shape(
  607. SmemLayoutAtomdQ{},
  608. make_shape(Int<kBlockM>{}, Int<kHeadDim>{}, _2{})));
  609. // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{})));
  610. // using SmemLayoutdQacct =
  611. // decltype(cute::composition(SmemLayoutdQacc{},
  612. // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  613. // make_stride(Int<kBlockM>{}, _1{}))));
  614. using RmemTiledCopydQacc = decltype(
  615. make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
  616. GmemLayoutAtomdQaccum{},
  617. Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
  618. // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
  619. using SmemCopyAtomPdS = Copy_Atom<
  620. std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  621. Element>;
  622. using SmemCopyAtomdKV = Copy_Atom<
  623. std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  624. Element>;
  625. using SmemCopyAtomdQ = Copy_Atom<
  626. std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  627. Element>;
  628. using SharedStorage = std::conditional_t<
  629. !Is_WS,
  630. SharedStorageQKVdOdKV<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
  631. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdK, SmemLayoutdV>,
  632. SharedStorageQKVdOdKVWS<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
  633. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQacc, SmemLayoutdK, SmemLayoutdV>
  634. // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV>
  635. >;
  636. // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
  637. // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
  638. using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  639. };
////////////////////////////////////////////////////////////////////////////////////////////////////
  641. template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
  642. bool SdP_swapAB_, bool dKV_swapAB_, bool dQ_swapAB_,
  643. int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
  644. int kClusterN_ = 1, typename elem_type=cutlass::half_t>
  645. struct Flash_bwd_seqqpar_kernel_traits {
  646. using Element = elem_type;
  647. using ElementAccum = float;
  648. using index_t = int64_t;
  649. // The number of threads.
  650. static constexpr int kNWarps = kNWarps_;
  651. static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
  652. static_assert(kNWarps_ == 8);
  653. static constexpr int kBlockM = kBlockM_;
  654. static constexpr int kBlockN = kBlockN_;
  655. static constexpr int kHeadDim = kHeadDim_;
  656. static_assert(kHeadDim % 32 == 0);
  657. using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  658. static constexpr int kClusterN = kClusterN_;
  659. using ClusterShape_MNK = Shape<_1, Int<kClusterN>, _1>;
  660. static constexpr int kStages = 2;
  661. static constexpr bool SdP_swapAB = SdP_swapAB_;
  662. static constexpr bool dKV_swapAB = dKV_swapAB_;
  663. static constexpr bool dQ_swapAB = dQ_swapAB_;
  664. static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV
  665. static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS
  666. using TileShapeAtomSdP = std::conditional_t<
  667. !SdP_swapAB,
  668. Shape<Int<kBlockM>, Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kHeadDim>>,
  669. Shape<Int<kBlockN / (2 / AtomLayoutMSdP)>, Int<kBlockM>, Int<kHeadDim>>
  670. >;
  671. using AtomLayoutSdP = std::conditional_t<
  672. !SdP_swapAB,
  673. Layout<Shape<Int<AtomLayoutMSdP>, Int<2 / AtomLayoutMSdP>, _1>>,
  674. Layout<Shape<Int<2 / AtomLayoutMSdP>, Int<AtomLayoutMSdP>, _1>>
  675. >;
  676. using TiledMmaSdP = decltype(cute::make_tiled_mma(
  677. cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomSdP>(),
  678. AtomLayoutSdP{}));
  679. using TileShapeAtomdKV = std::conditional_t<
  680. !dKV_swapAB,
  681. Shape<Int<kBlockN>, Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockM>>,
  682. Shape<Int<kHeadDim / (2 / AtomLayoutNdKV)>, Int<kBlockN>, Int<kBlockM>>
  683. >;
  684. using AtomLayoutdKV = std::conditional_t<
  685. !dKV_swapAB,
  686. Layout<Shape<Int<AtomLayoutNdKV>, Int<2 / AtomLayoutNdKV>, _1>>,
  687. Layout<Shape<Int<2 / AtomLayoutNdKV>, Int<AtomLayoutNdKV>, _1>>
  688. >;
  689. using TiledMmadKV = decltype(cute::make_tiled_mma(
  690. std::conditional_t<
  691. !SdP_swapAB,
  692. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::MN, GMMA::Major::MN>()),
  693. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdKV, GMMA::Major::K, GMMA::Major::MN>())
  694. >{},
  695. AtomLayoutdKV{}));
  696. using TileShapeAtomdQ = std::conditional_t<
  697. !dQ_swapAB,
  698. Shape<Int<kBlockM>, Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockN>>,
  699. Shape<Int<kHeadDim / (2 / AtomLayoutMdQ)>, Int<kBlockM>, Int<kBlockN>>
  700. >;
  701. using AtomLayoutdQ = std::conditional_t<
  702. !dQ_swapAB,
  703. Layout<Shape<Int<AtomLayoutMdQ>, Int<2 / AtomLayoutMdQ>, _1>>,
  704. Layout<Shape<Int<2 / AtomLayoutMdQ>, Int<AtomLayoutMdQ>, _1>>
  705. >;
  706. static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN;
  707. static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K;
  708. using TiledMmadQ = decltype(cute::make_tiled_mma(
  709. std::conditional_t<
  710. !dQ_swapAB,
  711. std::conditional_t<
  712. Mma_dQ_is_RS,
  713. decltype(cute::GMMA::rs_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>()),
  714. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::K, GMMA::Major::MN>())
  715. >,
  716. decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShapeAtomdQ, GMMA::Major::MN, GMMA::Major::K>())
  717. >{},
  718. AtomLayoutdQ{}));
  719. using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  720. using GmemTiledCopyKV = cute::SM90_TMA_LOAD;
  721. using GmemTiledCopydKV = cute::SM90_TMA_STORE;
  722. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  723. static constexpr bool Has_cp_async = true;
  724. #else
  725. static constexpr bool Has_cp_async = false;
  726. #endif
  727. // For the dot_do_o preprocessing kernel
  728. using Gmem_copy_struct = std::conditional_t<
  729. Has_cp_async,
  730. SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
  731. DefaultCopy
  732. >;
  733. static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
  734. static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
  735. static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
  736. // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
  737. // to affect speed in practice.
  738. static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
  739. static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
  740. using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
  741. Stride<Int<kGmemThreadsPerRow>, _1>>;
  742. using GmemTiledCopydO = decltype(
  743. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  744. GmemLayoutAtom{},
  745. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  746. using GmemTiledCopydQ = decltype(
  747. make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
  748. GmemLayoutAtom{},
  749. Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
  750. using GmemLayoutAtomdQaccum = std::conditional_t<
  751. kBlockKSmem == 32,
  752. Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
  753. Stride< _8, _1>>,
  754. Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
  755. Stride< _16, _1>>
  756. >;
  757. using GmemTiledCopydQaccum = decltype(
  758. make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
  759. GmemLayoutAtomdQaccum{},
  760. Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
  761. using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  762. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  763. using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
  764. using SmemLayoutdO = SmemLayoutQ;
  765. using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  766. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  767. using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
  768. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  769. using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  770. decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  771. using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
  772. make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
  773. using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  774. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  775. using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{})));
  776. using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
  777. decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
  778. using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{})));
  779. // Note this is the transpose in terms of the view, not in terms of memory.
  780. using SmemLayoutQt =
  781. decltype(cute::composition(SmemLayoutQ{},
  782. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  783. make_stride(Int<kBlockM>{}, _1{}))));
  784. using SmemLayoutdOt =
  785. decltype(cute::composition(SmemLayoutdO{},
  786. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  787. make_stride(Int<kBlockM>{}, _1{}))));
  788. using SmemLayoutKt =
  789. decltype(cute::composition(SmemLayoutK{},
  790. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
  791. make_stride(Int<kBlockN>{}, _1{}, Int<kBlockN * kHeadDim>{}))));
  792. using SmemLayoutPt =
  793. decltype(cute::composition(SmemLayoutP{},
  794. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  795. make_stride(Int<kBlockM>{}, _1{}))));
  796. using SmemLayoutdSt =
  797. decltype(cute::composition(SmemLayoutdS{},
  798. make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  799. make_stride(Int<kBlockM>{}, _1{}))));
  800. using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));
  801. using SmemLayoutdV = SmemLayoutdK;
  802. using SmemLayoutdKt = SmemLayoutKt;
  803. using SmemLayoutdVt = SmemLayoutKt;
  804. using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{})));
  805. static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
  806. using SmemLayoutAtomdQ = decltype(
  807. composition(Swizzle<kSwizzle, 3, 3>{},
  808. Layout<Shape<_8, Int<kBlockKSmem>>,
  809. Stride<Int<kBlockKSmem>, _1>>{}));
  810. using SmemLayoutdQ = decltype(tile_to_shape(
  811. SmemLayoutAtomdQ{},
  812. make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
  813. using SmemLayoutdQt =
  814. decltype(cute::composition(SmemLayoutdQ{},
  815. make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})),
  816. make_stride(Int<kBlockM>{}, _1{}))));
  817. static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
  818. using SmemLayoutAtomdKV = decltype(
  819. composition(Swizzle<kSwizzle, 3, 3>{},
  820. Layout<Shape<_8, Int<kBlockKSmem>>,
  821. Stride<Int<kBlockKSmem>, _1>>{}));
  822. using SmemLayoutdKV = decltype(tile_to_shape(
  823. SmemLayoutAtomdKV{},
  824. make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
  825. using SmemLayoutdKVt =
  826. decltype(cute::composition(SmemLayoutdKV{},
  827. make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
  828. make_stride(Int<kBlockN>{}, _1{}))));
  829. static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2;
  830. // using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
  831. using SmemCopyAtomPdS = Copy_Atom<
  832. std::conditional_t<!SdP_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  833. Element>;
  834. using SmemCopyAtomdKV = Copy_Atom<
  835. std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  836. Element>;
  837. using SmemCopyAtomdQ = Copy_Atom<
  838. std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
  839. Element>;
  840. using SharedStorage = SharedStorageQKVdOdKVSeqqPar<!SdP_swapAB, kStages, Element, Element, SmemLayoutQ, SmemLayoutdO,
  841. SmemLayoutK, SmemLayoutV, SmemLayoutP, SmemLayoutdS, SmemLayoutdQTMA>;
  842. // using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages * 2>;
  843. // using PipelineState = typename cutlass::PipelineState<kStages * 2>;
  844. using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  845. };
////////////////////////////////////////////////////////////////////////////////////////////////////