machete_mainloop.cuh

  1. //
  2. // Based off of:
  3. // cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
  4. // Specifically:
  5. // https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
  6. // Referred to as upstream in the comments
  7. //
  8. // The main optimization machete implements compared to upstream is to prepack
  9. // the weight matrix to more closely match the shape of the wgmma instructions
  10. // allowing for wider (ideally 128bit) shared memory loads. For subbyte types
  11. // this is done by packing values from multiple wgmma loads (for a single
  12. // thread) into a single 128bit load. This is very similar to the layout used
  13. // in Marlin, although specific to the wgmma instructions.
  14. //
  15. // Since the wgmma instructions only support sourcing from registers for the A
  16. // operand, and we want to upconvert/decompress the weight values/elements
  17. // before feeding them into the tensor cores in registers, we need the weight
  18. // matrix to be A. To achieve this we compute the transpose of Y = XW^t as
  19. // Y^t = W^tX^t. This is mostly done outside of this file in
  20. // csrc/quantization/machete/machete_mm_kernel.cuh, but this is why A is the
  21. // quantized/narrow type and has the prepacked layout despite the API being:
  22. // B_prepacked = machete_prepack_B(B)
  23. // Y = machete_mm(A, B_prepacked)
  24. //
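//
// To make the operand swap concrete, a small sketch (illustrative only, not
// part of the kernel): for the API-level product Y = A @ B_prepacked this
// file works on the transposed problem
//   Y = A * B   =>   Y^t = B^t * A^t
// so the prepacked, quantized B becomes the wgmma "A" operand (sourced from
// registers after upconversion) and the activations become the wgmma "B"
// operand (sourced from shared memory). With hypothetical shapes
//   A: (M, K) activations, B: (K, N) quantized weights, Y: (M, N)
// the kernel internally computes B^t (N, K) x A^t (K, M) -> Y^t (N, M).
//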
  25. #pragma once
  26. // clang-format off
  27. #include "cutlass/cutlass.h"
  28. #include "cutlass/numeric_conversion.h"
  29. #include "cute/arch/cluster_sm90.hpp"
  30. #include "cute/arch/copy_sm90.hpp"
  31. #include "cutlass/gemm/gemm.h"
  32. #include "cutlass/detail/dependent_false.hpp"
  33. #include "cutlass/gemm/dispatch_policy.hpp"
  34. #include "cutlass/detail/layout.hpp"
  35. #include "cute/algorithm/functional.hpp"
  36. #include "cute/atom/mma_atom.hpp"
  37. #include "cute/atom/copy_traits_sm90_tma.hpp"
  38. #include "cute/algorithm/gemm.hpp"
  39. #include "cute/tensor_predicate.hpp"
  40. #include "cute/numeric/arithmetic_tuple.hpp"
  41. #include "cutlass/pipeline/pipeline.hpp"
  42. #include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
  43. #include "cutlass/trace.h"
  44. #include "cutlass/detail/collective.hpp"
  45. // clang-format on
  46. #include "cutlass_extensions/cute_utils.cuh"
  47. namespace machete {
  48. using namespace cute;
  49. using namespace cutlass;
  50. using namespace cutlass::gemm;
  51. using namespace cutlass::gemm::collective;
  52. using namespace cutlass::gemm::collective::detail;
  53. template <class ElementATuple_, class GmemLayoutA, int AlignmentA,
  54. class ElementB_, class GmemLayoutB, int AlignmentB,
  55. class ElementAccumulator_, class TileShape_MNK,
  56. class ClusterShape_MNK, class StageCountType,
  57. class KernelScheduleType>
  58. struct MacheteCollectiveMma {
  59. using Schedule = KernelScheduleType;
  60. static_assert(
  61. cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
  62. cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
  63. cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
  64. cute::is_same_v<Schedule,
  65. KernelTmaWarpSpecializedPingpongMixedInput> ||
  66. cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
  67. cute::is_same_v<Schedule,
  68. KernelTmaWarpSpecializedCooperativeMixedInput>,
  69. "KernelSchedule must be one of the warp specialized policies");
  70. public:
  71. static constexpr bool ALayoutIsPrepacked = true;
  72. // Prepacked block shape (N is M in the transposed problem)
  73. using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK;
  74. // Prepacked blocks per dim for a single MMA tile
  75. using PPBlocksPerTile_MK = decltype(make_shape(
  76. size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}),
  77. size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{})));
  78. using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout;
  79. static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0,
  80. "M in PPBlockShape_MK must evenly divide M TileShape_MNK");
  81. static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0,
  82. "K in PPBlockShape_MK must evenly divide K TileShape_MNK");
  83. using ArchTag = arch::Sm90;
  84. using TileShape = TileShape_MNK;
  85. using ClusterShape = ClusterShape_MNK;
  86. using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>;
  87. using StrideA = TagToStrideA_t<layout::RowMajor>;
  88. using ElementB = ElementB_;
  89. using StrideB = TagToStrideB_t<GmemLayoutB>;
  90. using ElementAccumulator = ElementAccumulator_;
  91. using ElementMma = ElementB;
  92. using ElementATuple =
  93. cute::conditional_t<!cute::is_tuple<ElementATuple_>::value,
  94. cute::tuple<ElementA>, ElementATuple_>;
  95. static constexpr cute::GMMA::Major GmmaMajorA =
  96. gmma_rs_tag_to_major_A<layout::RowMajor>();
  97. static constexpr cute::GMMA::Major GmmaMajorB =
  98. gmma_rs_tag_to_major_B<GmemLayoutB>();
  99. // For coop schedules we have two warp groups cooperatively issuing wgmma
  100. // instructions so we use 2 atoms along the M dim (one for each warpgroup)
  101. using AtomLayoutMNK = cute::conditional_t<
  102. cute::is_same_v<KernelScheduleType,
  103. KernelTmaWarpSpecializedCooperativeMixedInput>,
  104. Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
  105. using TiledMma = decltype(cute::make_tiled_mma(
  106. cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
  107. TileShape_MNK, GMMA::Major::K, GmmaMajorB>(),
  108. AtomLayoutMNK{}));
  109. private:
  110. //
  111. // the setup section (until "section setup end") contains a combination of
  112. // modified code from (used as a starting point):
  113. // `cutlass/gemm/collective/builders/sm90_gmma_builder.inl`
  114. // `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp`
  115. // (upstream)
  116. //
  117. // however, in order to simplify the code, we combine a lot of the logic from
  118. // `CollectiveMma` and `CollectiveBuilder` into this class. This also makes
  119. // sense given that we have flexibility on layouts here. We also simplify the
  120. // code by only supporting scales and zeros for A (in the transposed problem,
  121. // B from an API perspective), and since we force A to be the narrow type
  122. // (i.e. the type to be upconverted) we can remove all the `SwapAB` logic from
  123. // upstream, simplifying the code further. This section also includes new
  124. // logic (compared to upstream) for handling the prepacked-A layouts (in the
  125. // transposed problem, B from an API perspective).
  126. //
  127. using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>;
  128. using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>;
  129. static constexpr bool IsANarrow = cutlass::sizeof_bits<ElementA>::value <
  130. cutlass::sizeof_bits<ElementB>::value;
  131. static_assert(IsANarrow,
  132. "A must be the narrow one since its the one that flows through "
  133. "registers.");
  134. public:
  135. static constexpr int PipelineStages =
  136. compute_stage_count_or_override_single_affine_transformed_input<
  137. sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale,
  138. ElementZero, TileShape_MNK>(StageCountType{});
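// A rough sketch of what the stage-count helper above balances (the real
// logic lives in the CUTLASS collective builders; all numbers below are
// hypothetical): every pipeline stage must fit one tile of A, B, scales and
// zeros in shared memory, so roughly
//   PipelineStages ~= sm90_smem_capacity_bytes / bytes_per_stage
// e.g. for a 128x128x64 (MxNxK) tile with 4-bit A, 16-bit B and 16-bit
// scales/zeros (one scale row per stage):
//   A: 128*64*0.5B = 4096B, B: 128*64*2B = 16384B, scales+zeros: ~512B
// i.e. ~21KB per stage, so a ~200KB smem budget yields on the order of 8-9
// stages (minus space reserved for mbarriers and the epilogue).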
  139. struct DispatchPolicy {
  140. constexpr static int Stages = PipelineStages;
  141. using ClusterShape = ClusterShape_MNK;
  142. using Schedule = KernelScheduleType;
  143. };
  144. using GmemTiledCopyA =
  145. decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  146. using GmemTiledCopyB =
  147. decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
  148. // ((T, V), (BlocksM, BlocksK), pipe) -> offset
  149. using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset(
  150. make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
  151. Int<DispatchPolicy::Stages>{})));
  152. using SmemLayoutAtomARowMajor =
  153. decltype(rs_smem_selector<GmmaMajorA, ElementA,
  154. decltype(cute::get<0>(TileShape_MNK{})),
  155. decltype(cute::get<2>(TileShape_MNK{}))>());
  156. using SmemLayoutAtomScale = Layout<
  157. Shape<decltype(cute::shape<0>(SmemLayoutAtomARowMajor{})), cute::Int<1>>>;
  158. using SmemLayoutAtomB =
  159. decltype(rs_smem_selector<GmmaMajorB, ElementB,
  160. decltype(cute::get<1>(TileShape_MNK{})),
  161. decltype(cute::get<2>(TileShape_MNK{}))>());
  162. using SmemCopyAtomA = Copy_Atom<cute::DefaultCopy, ElementA>;
  163. using SmemCopyAtomB = void;
  164. //
  165. // Validity checks
  166. //
  167. static_assert(is_static<TileShape_MNK>::value);
  168. static_assert(is_static<ClusterShape_MNK>::value);
  169. static_assert(is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
  170. tma_alignment_bytes>(),
  171. "Should meet TMA alignment requirement\n");
  172. #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  173. static_assert(cutlass::detail::dependent_false<ElementA>,
  174. "Unsupported Toolkit for SM90 Collective Builder\n");
  175. #endif
  176. private:
  177. enum class ConversionMode {
  178. DirectConvert,
  179. ConvertAndScale,
  180. ConvertAndScaleWithZero
  181. };
  182. public:
  183. //
  184. // Type Aliases
  185. //
  186. using KernelSchedule = KernelScheduleType;
  187. // For cases where we can't have a void type, we can use this to allow the
  188. // code to compile when the scale / zero is void.
  189. using NonVoidElementScale =
  190. cute::conditional_t<cute::is_void_v<ElementScale>, float, ElementScale>;
  191. using NonVoidElementZero =
  192. cute::conditional_t<cute::is_void_v<ElementZero>, float, ElementZero>;
  193. // These are always MN major
  194. using StrideScale = cute::Stride<cute::Int<1>, int64_t, int64_t>;
  195. // For cases where we can't have a void scale, we can use this to allow the
  196. // code to compile when the scale is void.
  197. using NonVoidStrideScale =
  198. cute::conditional_t<cute::is_void_v<StrideScale>,
  199. cute::Stride<_1, int64_t, int64_t>, StrideScale>;
  200. static_assert((cutlass::gemm::detail::is_k_major<StrideA>()),
  201. "The transformed matrix (A) must be K-major.");
  202. static_assert((sizeof(ElementB) == 2) ||
  203. (cutlass::gemm::detail::is_k_major<StrideA>() &&
  204. cutlass::gemm::detail::is_k_major<StrideB>()),
  205. "The unscaled element (matrix B) must be 2 bytes OR both "
  206. "inputs must be K-major");
  207. static_assert(cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(),
  208. "Scale must be MN major [Col Major if A is scaled, Row Major "
  209. "if B is scaled].");
  210. static_assert(std::is_same_v<typename TiledMma::ValTypeC, ElementAccumulator>,
  211. "TiledMma::ValTypeC must be the same as ElementAccumulator.");
  212. using GmemTiledCopyScale = cute::SM90_TMA_LOAD;
  213. using SmemCopyAtomScale = Copy_Atom<cute::DefaultCopy, NonVoidElementScale>;
  214. // TMA converts f32 input to tf32 when copying from GMEM to SMEM
  215. // For all other types, cast to size equivalent uint type to avoid any
  216. // rounding by TMA.
  217. static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
  218. static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
  219. using InternalElementA =
  220. cute::conditional_t<ConvertF32toTF32A, tfloat32_t,
  221. uint_bit_t<sizeof_bits_v<ElementA>>>;
  222. using InternalElementB =
  223. cute::conditional_t<ConvertF32toTF32B, tfloat32_t,
  224. uint_bit_t<sizeof_bits_v<ElementB>>>;
  225. using TransformA = cute::identity;
  226. using TransformB = cute::identity;
  227. static constexpr int IsSubbyteA = cute::sizeof_bits_v<InternalElementA> < 8;
  228. using TmaElementA =
  229. cute::conditional_t<IsSubbyteA, uint8_t, InternalElementA>;
  230. using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
  231. using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
  232. using PipelineParams = typename MainloopPipeline::Params;
  233. using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
  234. shape<1>(SmemLayoutAtomScale{})));
  235. static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
  236. "SmemLayoutAtom must be rank 2 (M/N, K)");
  237. static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
  238. "SmemLayoutAtom must evenly divide tile shape.");
  239. static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
  240. "SmemLayoutAtom must evenly divide tile shape.");
  241. static_assert(rank(SmemLayoutAtomScale{}) == 2,
  242. "SmemLayoutAtomScale must be rank 2");
  243. static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0,
  244. "SmemLayoutAtomScale must equal the tile shape.");
  245. static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
  246. "SmemLayoutAtomScale must evenly divide tile k shape.");
  247. // Tile along modes in a way that maximizes the TMA box size.
  248. using SmemLayoutACopy = decltype(tile_to_shape(
  249. SmemLayoutAtomARowMajor{},
  250. make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
  251. Int<DispatchPolicy::Stages>{}),
  252. conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
  253. Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
  254. using SmemLayoutB = decltype(tile_to_shape(
  255. SmemLayoutAtomB{},
  256. make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
  257. Int<DispatchPolicy::Stages>{}),
  258. conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
  259. Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
  260. // It is assumed that the scales and zero-points share the same smem layout
  261. using SmemLayoutScale = decltype(tile_to_shape(
  262. SmemLayoutAtomScale{},
  263. make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}),
  264. Int<PipelineStages>{})));
  265. // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major
  266. // only (e.g. tf32, fp32, fp8, int8).
  267. static constexpr bool IsLayoutAmnBmn =
  268. cute::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>,
  269. layout::ColumnMajor> &&
  270. cute::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>,
  271. layout::RowMajor>;
  272. static_assert(DispatchPolicy::Stages >= 2,
  273. "Specialization requires Stages set to value 2 or more.");
  274. static_assert(not cute::is_base_of<cute::GMMA::DescriptorIterator,
  275. typename TiledMma::FrgTypeA>::value &&
  276. cute::is_base_of<cute::GMMA::DescriptorIterator,
  277. typename TiledMma::FrgTypeB>::value,
  278. "MMA atom must source A from rmem and B operand from smem_desc "
  279. "for this mainloop.");
  280. static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
  281. cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
  282. "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  283. static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
  284. cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
  285. "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  286. using GmmaSmemLayoutB = decltype(tile_to_shape(
  287. SmemLayoutAtomB{},
  288. make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
  289. Int<DispatchPolicy::Stages>{}),
  290. conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
  291. Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
  292. // These two restrictions are related, so we place the assertions together.
  293. // To relax them, we need to handle loading more than 1 row of scales for
  294. // every main loop iteration. We must also handle updating the pipeline
  295. // transaction bytes on the fly. NOTE: Deleting this assertion without
  296. // required changes will cause the code to hang.
  297. static_assert(size<1>(SmemLayoutAtomScale{}) == 1,
  298. "size<1>(SmemLayoutAtomScale) must be 1.");
  299. private:
  300. static constexpr ConversionMode get_conversion_mode() {
  301. if constexpr (cute::is_void_v<ElementScale>) {
  302. return ConversionMode::DirectConvert;
  303. } else if constexpr (cute::is_void_v<ElementZero>) {
  304. return ConversionMode::ConvertAndScale;
  305. } else {
  306. return ConversionMode::ConvertAndScaleWithZero;
  307. }
  308. }
  309. static constexpr ConversionMode KernelConversionMode = get_conversion_mode();
  310. static constexpr bool ModeHasScales =
  311. KernelConversionMode == ConversionMode::ConvertAndScale ||
  312. KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;
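// For intuition, example instantiations of `ElementATuple_` and the
// conversion mode they select (the concrete types are just hypothetical
// examples; ElementScale/ElementZero are deduced from tuple positions 1 and
// 2 above and default to void when absent):
//   cutlass::uint4b_t                                  -> DirectConvert
//   cute::tuple<cutlass::uint4b_t, cutlass::half_t>    -> ConvertAndScale
//   cute::tuple<cutlass::uint4b_t, cutlass::half_t, cutlass::half_t>
//                                                      -> ConvertAndScaleWithZero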
  313. // Same as upstream, should be kept the same when possible
  314. static constexpr auto elements_per_smem_scale() {
  315. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
  316. return 0;
  317. } else if constexpr (ModeHasScales) {
  318. return cute::cosize_v<SmemLayoutScale>;
  319. } else {
  320. static_assert(cutlass::detail::dependent_false<KernelSchedule>,
  321. "Type not handled in scale smem allocation.");
  322. }
  323. }
  324. // Same as upstream, should be kept the same when possible
  325. static constexpr auto elements_per_smem_zero() {
  326. if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
  327. KernelConversionMode == ConversionMode::ConvertAndScale) {
  328. return 0;
  329. } else if constexpr (KernelConversionMode ==
  330. ConversionMode::ConvertAndScaleWithZero) {
  331. return cute::cosize_v<SmemLayoutScale>;
  332. } else {
  333. static_assert(cutlass::detail::dependent_false<KernelSchedule>,
  334. "Type not handled in scale smem allocation.");
  335. }
  336. }
  337. // Same as upstream, should be kept the same when possible, not formatted for
  338. // easier comparison
  339. // clang-format off
  340. // These methods use some of the public members of the class. For that reason, we define them after the public section.
  341. static constexpr uint32_t
  342. compute_tma_transaction_bytes_mk() {
  343. constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<InternalElementA>));
  344. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
  345. return baseline_bytes;
  346. }
  347. else if constexpr (ModeHasScales) {
  348. constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
  349. static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA
  350. if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
  351. return baseline_bytes + scale_tx_bytes;
  352. }
  353. else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
  354. // Scale and zero share smem layout
  355. constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
  356. static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA
  357. return baseline_bytes + scale_tx_bytes + zero_tx_bytes;
  358. }
  359. else {
  360. static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
  361. }
  362. }
  363. else {
  364. static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
  365. }
  366. }
  367. static constexpr uint32_t
  368. compute_tma_transaction_bytes_nk() {
  369. return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<InternalElementB>));
  370. }
  371. // clang-format on
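// A worked example of the transaction-byte computation above (hypothetical
// sizes): for a 128 (M) x 64 (K) stage of 4-bit A with 16-bit scales and
// zeros, the A stage is 128 * 64 * 4 / 8 = 4096 bytes and each scale (or
// zero) stage is 128 * 16 / 8 = 256 bytes, so in ConvertAndScaleWithZero mode
//   TmaTransactionBytesMK = 4096 + 256 + 256 = 4608 bytes
// (and 256 % 128 == 0 satisfies the TMA alignment static_asserts). The NK
// bytes are computed the same way from SmemLayoutB, and the two are summed
// into TmaTransactionBytes further below.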
  372. // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
  373. using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset(
  374. make_shape(int32_t(0), int32_t(0), int32_t(0)))));
  375. using ATensor = decltype(make_tensor(
  376. get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
  377. shape(GmemLayoutA::TVbNbKL_to_offset(
  378. make_shape(int32_t(0), int32_t(0), int32_t(0)))),
  379. PrepackedStrideA{}));
  380. using BTensor = decltype(make_tensor(
  381. get_logical_ptr(static_cast<InternalElementB const*>(nullptr)),
  382. repeat_like(StrideB{}, int32_t(0)), StrideB{}));
  383. using ScaleTensor = decltype(make_tensor(
  384. get_logical_ptr(static_cast<NonVoidElementScale const*>(nullptr)),
  385. repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}));
  386. using ZeroTensor = decltype(make_tensor(
  387. get_logical_ptr(static_cast<NonVoidElementZero const*>(nullptr)),
  388. repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}));
  389. static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
  390. return make_tma_copy<TmaElementA>(
  391. GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
  392. shape(SmemLayoutA{}(_, _, cute::Int<0>{})),
  393. size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
  394. }
  395. static constexpr auto make_tma_copy_scale(
  396. ScaleTensor tensor_scale = ScaleTensor{}) {
  397. return make_tma_copy(GmemTiledCopyScale{}, tensor_scale,
  398. SmemLayoutScale{}(_, _, cute::Int<0>{}),
  399. ScaleTileShape{},
  400. _1{}); // mcast along N mode for this M load, if any
  401. }
  402. static constexpr auto make_tma_copy_zero(
  403. ZeroTensor tensor_zero = ZeroTensor{}) {
  404. return make_tma_copy(GmemTiledCopyScale{}, tensor_zero,
  405. SmemLayoutScale{}(_, _, cute::Int<0>{}),
  406. ScaleTileShape{},
  407. _1{}); // mcast along N mode for this M load, if any
  408. }
  409. static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) {
  410. return make_tma_copy(
  411. GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
  412. make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
  413. size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
  414. }
  415. public:
  416. // Same as upstream, should be kept the same when possible, not formatted for
  417. // easier comparison
  418. // with `RealInternalElementA` -> `ElementA` since we removed the `SwapAB` logic
  419. // clang-format off
  420. static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{});
  421. static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{});
  422. // Just pick the max alignment of A and B since it is required to be at least 128B
  423. static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB);
  424. static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment");
  425. struct SharedStorage
  426. {
  427. static constexpr int scale_elements = elements_per_smem_scale();
  428. static constexpr int zero_elements = elements_per_smem_zero();
  429. struct TensorStorage : cute::aligned_struct<cute::max(SmemAlignmentA, SmemAlignmentB)> {
  430. cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> smem_A;
  431. cute::ArrayEngine<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
  432. cute::ArrayEngine<NonVoidElementScale, scale_elements> smem_scale;
  433. cute::ArrayEngine<NonVoidElementZero, zero_elements> smem_zero;
  434. } tensors;
  435. using PipelineStorage = typename MainloopPipeline::SharedStorage;
  436. PipelineStorage pipeline;
  437. };
  438. using TensorStorage = typename SharedStorage::TensorStorage;
  439. using PipelineStorage = typename SharedStorage::PipelineStorage;
  440. // Host side kernel arguments
  441. struct Arguments {
  442. ElementA const* ptr_A = nullptr;
  443. StrideA dA{};
  444. ElementB const* ptr_B = nullptr;
  445. StrideB dB{};
  446. ElementScale const* ptr_S = nullptr;
  447. NonVoidStrideScale dS{};
  448. int group_size = 0;
  449. ElementZero const* ptr_Z = nullptr;
  450. uint32_t mma_promotion_interval = 4;
  451. };
  452. // clang-format on
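// To connect these host-side arguments back to the machete API (a sketch;
// the actual argument plumbing lives in machete_mm_kernel.cuh, and all names
// below are hypothetical): in the transposed problem ptr_A is the prepacked,
// quantized weight data (API-level B_prepacked), ptr_B is the activation
// data (API-level A), and ptr_S / ptr_Z are the per-group scales / zeros
// where `group_size` consecutive K elements share one scale, e.g.
//   Arguments{w_prepacked_ptr, {}, act_ptr, dB, scales_ptr, dS,
//             /*group_size=*/128, zeros_ptr};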
  453. //
  454. // section setup end
  455. //
  456. // Similar (but not identical) to upstream, should be kept the same when
  457. // possible
  458. // compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to
  459. // define the TMA types
  460. // Device side kernel params
  461. struct Params {
  462. public:
  463. // Assumption: StrideA is congruent with Problem_MK
  464. using TMA_A = decltype(make_tma_copy_A());
  465. using TMA_Scale = decltype(make_tma_copy_scale());
  466. using TMA_Zero = decltype(make_tma_copy_zero());
  467. using TMA_B = decltype(make_tma_copy_B());
  468. // required by outer loop: i.e.
  469. // cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp
  470. TMA_A tma_load_a;
  471. TMA_B tma_load_b;
  472. TMA_Scale tma_load_scale;
  473. TMA_Zero tma_load_zero;
  474. int64_t scale_k;
  475. int group_size;
  476. uint32_t tma_transaction_bytes = TmaTransactionBytes;
  477. uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
  478. uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
  479. };
  480. //
  481. // Methods
  482. //
  483. // Similar (but not identical) to upstream, should be kept the same when
  484. // possible
  485. // compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here
  486. // to handle the prepacked layout
  487. template <class ProblemShape>
  488. static constexpr Params to_underlying_arguments(
  489. ProblemShape const& problem_shape, Arguments const& args,
  490. void* workspace) {
  491. (void)workspace;
  492. // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
  493. // only rank-3 (MNK)
  494. auto problem_shape_MNKL = append<4>(problem_shape, 1);
  495. auto [M, N, K, L] = problem_shape_MNKL;
  496. auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
  497. auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
  498. auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) {
  499. return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride));
  500. };
  501. typename Params::TMA_A tma_load_a;
  502. typename Params::TMA_B tma_load_b;
  503. typename Params::TMA_Scale tma_load_scale;
  504. typename Params::TMA_Zero tma_load_zero;
  505. auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
  506. tma_load_a = make_tma_copy_A(
  507. make_logical_tensor(ptr_A, shape(layout), stride(layout)));
  508. tma_load_b = make_tma_copy_B(
  509. make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));
  510. if constexpr (ModeHasScales) {
  511. tma_load_scale = make_tma_copy_scale(make_logical_tensor(
  512. args.ptr_S, make_shape(M, args.group_size, L), args.dS));
  513. }
  514. if constexpr (KernelConversionMode ==
  515. ConversionMode::ConvertAndScaleWithZero) {
  516. tma_load_zero = make_tma_copy_zero(make_logical_tensor(
  517. args.ptr_Z, make_shape(M, args.group_size, L), args.dS));
  518. }
  519. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
  520. return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0};
  521. } else if constexpr (ModeHasScales) {
  522. auto scale_k = (K + args.group_size - 1) / args.group_size;
  523. return {tma_load_a, tma_load_b, tma_load_scale,
  524. tma_load_zero, scale_k, args.group_size};
  525. } else {
  526. static_assert(cutlass::detail::dependent_false<KernelSchedule>,
  527. "Conversion mode not handled in to_underlying_arguments.");
  528. }
  529. }
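// Worked example of the scale_k computation above (hypothetical sizes): for
// K = 4096 and group_size = 128 the ceiling divide gives
//   scale_k = (4096 + 127) / 128 = 32
// quantization groups along K; with per-channel quantization
// (group_size == K) it gives scale_k = 1.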
  530. // Same as upstream, should be kept the same when possible, not formatted for
  531. // easier comparison
  532. // with `SwapAB ? N : M -> M` since we don't support SwapAB
  533. // clang-format off
  534. template<class ProblemShape>
  535. static bool
  536. can_implement(
  537. ProblemShape const& problem_shape,
  538. [[maybe_unused]] Arguments const& args) {
  539. constexpr int tma_alignment_bits = 128;
  540. auto problem_shape_MNKL = append<4>(problem_shape, 1);
  541. auto [M,N,K,L] = problem_shape_MNKL;
  542. bool implementable = true;
  543. constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
  544. implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
  545. constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
  546. implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
  547. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
  548. implementable = implementable && (args.ptr_S == nullptr);
  549. implementable = implementable && (args.ptr_Z == nullptr);
  550. }
  551. else if constexpr (ModeHasScales) {
  552. const int scale_mn = M;
  553. const int scale_k = (K + args.group_size - 1) / args.group_size;
  554. constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
  555. implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
  556. implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
  557. implementable = implementable && args.group_size != 0;
  558. implementable = implementable && (args.ptr_S != nullptr);
  559. if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
  560. implementable = implementable && (args.ptr_Z == nullptr);
  561. }
  562. else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
  563. constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
  564. implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
  565. implementable = implementable && (args.ptr_Z != nullptr);
  566. }
  567. else {
  568. static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
  569. }
  570. }
  571. else {
  572. static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
  573. }
  574. if (!implementable) {
  575. CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
  576. }
  577. return implementable;
  578. }
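// A worked example of the checks above (hypothetical element types): with
// tma_alignment_bits = 128, a 4-bit ElementA needs its contiguous (K) extent
// to be a multiple of 128 / 4 = 32 elements, a 16-bit ElementB a multiple of
// 128 / 16 = 8 elements, and 16-bit scales a multiple of 8 elements. The
// group_size check accepts group_size == K (per-channel) or any group_size
// that is a multiple of the K tile, e.g. group_size = 128 with
// size<2>(TileShape{}) = 64.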
  579. static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
  580. static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk();
  581. static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk();
  582. static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
  583. /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  584. CUTLASS_DEVICE
  585. static void prefetch_tma_descriptors(Params const& mainloop_params) {
  586. cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
  587. cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
  588. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
  589. // Nothing extra to do
  590. }
  591. else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
  592. cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor());
  593. }
  594. else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
  595. cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor());
  596. cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor());
  597. }
  598. else {
  599. static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA prefetch.");
  600. }
  601. }
  602. // clang-format on
  603. // Modified from upstream, should be kept close to that when possible
  604. // the main difference is special handling for the prepacked A layout
  605. //
  606. // Set up the data needed by this collective for load and mma.
  607. // Returns a tuple of tensors. The collective and the kernel layer have the
  608. // contract that the returned tuple must contain at least two elements, with
  609. // the first two elements being:
  610. //   gA_mkl - The tma tensor, A after a local tile, with shape (TILE_V,TILE_B,m,k,l)
  611. //   gB_nkl - The tma tensor, B after a local tile, with shape (TILE_N,TILE_K,n,k,l)
  612. // The rest of the tensors can be specified as needed by this collective.
  613. // NOTE: TILE_B is the prepacked block index within a tile. TILE_V indexes the
  614. // values within a prepacked block.
  615. template <class ProblemShape_MNKL>
  616. CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL,
  617. Params const& mainloop_params) const {
  618. using X = Underscore;
  619. auto M = get<0>(problem_shape_MNKL), N = get<1>(problem_shape_MNKL),
  620. K = get<2>(problem_shape_MNKL), L = get<3>(problem_shape_MNKL);
  621. // (TILE_V,TILE_B,m,k,l)
  622. auto make_gA_mkl = [&]() {
  623. // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
  624. auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
  625. Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout));
  626. return local_tile(mA_mkl,
  627. make_shape(size<0>(layout), PPBlocksPerTile_MK{}),
  628. make_coord(0, make_coord(_, _)));
  629. };
  630. // (TILE_N,TILE_K,n,k,l)
  631. auto make_gB_nkl = [&]() {
  632. Tensor mB_nkl =
  633. mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L));
  634. return local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
  635. Step<X, _1, _1>{});
  636. };
  637. // (TILE_M,TILE_Scale_K,m,scale_k,l)
  638. auto make_gS_mkl = [&]() {
  639. auto scale_k = mainloop_params.scale_k;
  640. Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(
  641. make_shape(M, scale_k, L));
  642. return local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _));
  643. };
  644. // (TILE_M,TILE_Scale_K,m,scale_k,l)
  645. auto make_gZ_mkl = [&]() {
  646. auto scale_k = mainloop_params.scale_k;
  647. Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(
  648. make_shape(M, scale_k, L));
  649. return local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _));
  650. };
  651. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
  652. return cute::make_tuple(make_gA_mkl(), make_gB_nkl());
  653. } else if constexpr (KernelConversionMode ==
  654. ConversionMode::ConvertAndScale) {
  655. return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl());
  656. } else if constexpr (KernelConversionMode ==
  657. ConversionMode::ConvertAndScaleWithZero) {
  658. return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl(),
  659. make_gZ_mkl());
  660. } else {
  661. static_assert(cutlass::detail::dependent_false<KernelSchedule>,
  662. "Conversion mode not handled in load_init.");
  663. }
  664. }
  665. // Similar to upstream, should be kept close to that when possible
  666. // the main difference is in the layout comments
  667. // clang-format off
  668. /// Perform a collective-scoped matrix multiply-accumulate
  669. /// Producer Perspective
  670. /// This overload gets triggered when we have scales.
  671. template <
  672. class... Ts,
  673. class KTileIterator, class BlockCoord
  674. >
  675. CUTLASS_DEVICE void
  676. load(
  677. Params const& mainloop_params,
  678. MainloopPipeline pipeline,
  679. PipelineState smem_pipe_write,
  680. cute::tuple<Ts...> const& load_inputs,
  681. BlockCoord const& blk_coord,
  682. KTileIterator k_tile_iter, int k_tile_count,
  683. int thread_idx,
  684. uint32_t block_rank_in_cluster,
  685. TensorStorage& shared_tensors) {
  686. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
  687. static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs");
  688. }
  689. else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
  690. static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs");
  691. }
  692. else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
  693. static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs");
  694. }
  695. else {
  696. static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
  697. }
  698. int lane_predicate = cute::elect_one_sync();
  699. if (lane_predicate) {
  700. Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
  701. Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
  702. Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
  703. Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE)
  704. //
  705. // Prepare the TMA loads for A, B and Scales
  706. //
  707. constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
  708. uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
  709. Tensor gA_mkl = get<0>(load_inputs);
  710. Tensor gB_nkl = get<1>(load_inputs);
  711. auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
  712. auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
  713. // Partition the inputs based on the current block coordinates.
  714. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
  715. Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k)
  716. Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k)
  717. // Applies the mapping from block_tma_a
  718. Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
  719. Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
  720. Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
  721. Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
  722. uint16_t mcast_mask_a = 0;
  723. uint16_t mcast_mask_b = 0;
  724. uint16_t mcast_mask_s = 0;
  725. // Issue TmaLoads
  726. // Maps the tile -> block, value
  727. if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
  728. auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
  729. for (int n = 0; n < size<1>(block_layout); ++n) {
  730. mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
  731. }
  732. }
  733. if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
  734. auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
  735. for (int m = 0; m < size<0>(block_layout); ++m) {
  736. mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
  737. }
  738. }
  739. auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord);
  740. // Mainloop
  741. CUTLASS_PRAGMA_NO_UNROLL
  742. for ( ; k_tile_count > 0; --k_tile_count) {
  743. // LOCK smem_pipe_write for _writing_
  744. pipeline.producer_acquire(smem_pipe_write);
  745. //
  746. // Copy gmem to smem for *k_tile_iter
  747. //
  748. using BarrierType = typename MainloopPipeline::ProducerBarrierType;
  749. BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
  750. int write_stage = smem_pipe_write.index();
  751. copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
  752. copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
  753. if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
  754. // Nothing extra to do.
  755. }
  756. else if constexpr (ModeHasScales) {
  757. auto tSgS = get<0>(extra_input_partitions);
  758. auto tSsS = get<1>(extra_input_partitions);
  759. // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes
  760. // on the fly.
  761. // We must do a ceiling divide here to correctly handle the case where group_size == K. In that case, we don't require
  762. // that K is a multiple of the threadblock tile K.
  763. const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
  764. const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K.
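// Worked example (hypothetical sizes): with group_size = 128 and a K tile of size<2>(TileShape{}) = 64,
// ReloadFactor = (128 + 64 - 1) / 64 = 2, so k tiles 0,1,2,3,... map to scale_load_k 0,0,1,1,... i.e. the same
// scale row is re-issued for the 2 consecutive k tiles it covers; with group_size == K, ReloadFactor spans all
// k tiles and scale_load_k stays 0.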
  765. copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage));
  766. if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
  767. // Nothing extra to do
  768. }
  769. else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
  770. auto tZgZ = get<2>(extra_input_partitions);
  771. auto tZsZ = get<3>(extra_input_partitions);
  772. copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage));
  773. }
  774. else {
  775. static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
  776. }
  777. }
  778. else {
  779. static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
  780. }
  781. ++k_tile_iter;
  782. // Advance smem_pipe_write
  783. ++smem_pipe_write;
  784. }
  785. }
  786. }
  787. // clang-format on
  788. // Same as upstream, should be kept the same when possible, not formatted for
  789. // easier comparison
  790. // clang-format off
  791. // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  792. CUTLASS_DEVICE void
  793. load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
  794. int lane_predicate = cute::elect_one_sync();
  795. // Issue the epilogue waits
  796. if (lane_predicate) {
  797. /* This helps avoid early exit of blocks in Cluster
  798. * Waits for all stages to either be released (all
  799. * Consumer UNLOCKs), or if the stage was never used
  800. * then would just be acquired since the phase was
  801. * still inverted from make_producer_start_state
  802. */
  803. pipeline.producer_tail(smem_pipe_write);
  804. }
  805. }
  806. // clang-format on
  807. // Modified from upstream, should be kept close to that when possible
  808. // the main differences are handling the prepacked A layout, and separating
  809. // the loading of A from upconverting A
  810. //
  811. // Perform a collective-scoped matrix multiply-accumulate
  812. // Consumer Perspective
  813. template <class FrgTensorC>
  814. CUTLASS_DEVICE void mma(MainloopPipeline pipeline,
  815. PipelineState smem_pipe_read, FrgTensorC& accum,
  816. int k_tile_count, int thread_idx,
  817. TensorStorage& shared_tensors,
  818. Params const& mainloop_params) {
  819. static_assert(is_rmem<FrgTensorC>::value,
  820. "C tensor must be rmem resident.");
  821. static_assert(cute::rank(SmemLayoutB{}) == 3,
  822. "Smem layout must be rank 3.");
  823. static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
  824. "SmemLayoutAtomB must be rank 2.");
  825. static_assert(!cute::is_void_v<SmemCopyAtomA>,
  826. "SM90 GMMA mainloops must specify a non-void copy atom for "
  827. "RF sourced instructions.");
  828. static_assert(cute::is_void_v<SmemCopyAtomB>,
  829. "SM90 GMMA mainloops cannot have a non-void copy atom for "
  830. "smem sourced instructions.");
  831. // Obtain warp index
  832. int warp_idx = canonical_warp_idx_sync();
  833. [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128;
  834. // ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset
  835. auto constexpr smem_A = SmemLayoutA{};
  836. // convert:
  837. // ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset
  838. // to:
  839. // (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset
  840. // which can be thought of as:
  841. // (T, MMA, (MMA_M, MMA_K), pipe) -> offset
  842. auto constexpr smem_A_mma_ =
  843. make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A),
  844. zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A));
  845. // flatten to:
  846. // (T, MMA, MMA_M, MMA_K, pipe) -> offset
  847. auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _);
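// A toy example of the regrouping above (shapes are hypothetical, not the
// real prepacked layout): starting from a layout with shape
//   ((T, (MMA, (MMA_M, MMA_K))), (BlocksM, BlocksK), pipe)
//     = ((128, (8, (2, 4))), (1, 2), 3)
// the make_layout + zip above yields shape (128, 8, ((2, 1), (4, 2)), 3),
// and slicing with make_coord(_, _) splits the zipped mode back into two
// modes, (MMA_M, BlocksM) and (MMA_K, BlocksK), so the tensor is indexed as
// (T, MMA, MMA_M, MMA_K, pipe) with mode sizes (128, 8, 2, 8, 3) here.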
  848. Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()),
  849. smem_A_mma); // (T, MMA, MMA_M, MMA_K, pipe)
  850. Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()),
  851. SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
  852. //
  853. // Define C accumulators and A/B partitioning
  854. //
  855. TiledMma tiled_mma;
  856. auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
  857. Tensor tCsA = sA(thread_idx, _, _, _, _); // (MMA,MMA_M,MMA_K,PIPE)
  858. Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
  859. // Allocate fragments and descriptors
  860. Tensor tCrA_load = make_tensor<ElementA>(
  861. tCsA(_, _, _, Int<0>{}).shape()); // (MMA,MMA_N,MMA_K)
  862. Tensor tCrA_mma = make_fragment_like<ElementMma>(tCrA_load);
  863. Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
  864. static constexpr int A_CPY_VEC =
  865. decltype(max_common_vector(tCsA, tCrA_load)){};
  866. static constexpr int COVERSION_WIDTH =
  867. std::min(A_CPY_VEC, int(size<0>(tCrA_mma)));
  868. auto load_A_to_registers = [&](int read_stage) {
  869. copy(create_auto_vectorizing_copy<ElementA, decltype(A_CPY_VEC)>(),
  870. tCsA(_, _, _, read_stage), tCrA_load(_, _, _));
  871. };
  872. // Partition of thread -> shared and thread -> RF
  873. auto partitioned_extra_info =
  874. partition_extra_mma_info(thread_mma, shared_tensors);
  875. auto copy_partitions_extra_info = retile_extra_mma_info(
  876. tiled_mma, partitioned_extra_info, warp_group_thread_idx);
  877. CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M
  878. CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
  879. CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
  880. CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
  881. CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
  882. //
  883. // PIPELINED MAIN LOOP
  884. //
  885. auto convert_A = [&, a_vec = Int<COVERSION_WIDTH>{}](int k_block,
  886. int read_stage) {
  887. load_extra_info_to_registers(partitioned_extra_info,
  888. copy_partitions_extra_info, k_block,
  889. read_stage);
  890. transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info,
  891. k_block);
  892. };
  893. // We release buffers to producer warps (dma load) with some mmas in flight
  894. PipelineState smem_pipe_release = smem_pipe_read;
  895. tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
  896. warpgroup_fence_operand(accum);
  897. constexpr int K_BLOCK_MAX = size<2>(tCrA_load);
  898. ConsumerToken barrier_token = {BarrierStatus::WaitAgain};
  899. // first k tile
  900. {
  901. barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  902. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  903. int read_stage = smem_pipe_read.index();
  904. ++smem_pipe_read;
  905. barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  906. // copy smem->rmem for A operand
  907. load_A_to_registers(read_stage);
  908. convert_A(0, read_stage);
  909. // Unroll the K mode manually to set scale D to 1
  910. CUTLASS_PRAGMA_UNROLL
  911. for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
  912. if (k_block < K_BLOCK_MAX - 1) {
  913. convert_A(k_block + 1, smem_pipe_read.index());
  914. }
  915. warpgroup_arrive();
  916. // (V,M) x (V,N) => (V,M,N)
  917. cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
  918. tCrB(_, _, k_block, read_stage), accum);
  919. tiled_mma.accumulate_ = GMMA::ScaleOut::One;
  920. warpgroup_commit_batch();
  921. }
  922. --k_tile_count;
  923. if (k_tile_count > 0) {
  924. // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to
  925. // overwrite the A registers for the first mma.
  926. warpgroup_wait<K_BLOCK_MAX - 1>();
  927. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  928. load_A_to_registers(smem_pipe_read.index());
  929. convert_A(0, smem_pipe_read.index());
  930. }
  931. }
  932. if (k_tile_count == 0) {
  933. return;
  934. }
  935. warpgroup_fence_operand(accum);
  936. // Mainloop GMMAs
  937. CUTLASS_PRAGMA_NO_UNROLL
  938. for (; k_tile_count > 1; --k_tile_count) {
  939. //
  940. // Compute on k_tile
  941. //
  942. int read_stage = smem_pipe_read.index();
  943. ++smem_pipe_read;
  944. warpgroup_fence_operand(accum);
  945. // Unroll the K mode manually to set scale D to 1
  946. CUTLASS_PRAGMA_UNROLL
  947. for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
  948. warpgroup_arrive();
  949. // (V,M) x (V,N) => (V,M,N)
  950. cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
  951. tCrB(_, _, k_block, read_stage), accum);
  952. tiled_mma.accumulate_ = GMMA::ScaleOut::One;
  953. warpgroup_commit_batch();
  954. warpgroup_wait<K_BLOCK_MAX - 1>();
  955. if (k_block == K_BLOCK_MAX - 1) {
  956. // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage,
  957. // so we can release prior barrier
  958. pipeline.consumer_release(
  959. smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_
  960. // on it
  961. ++smem_pipe_release;
  962. }
  963. if (k_block == 0) {
  964. barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
  965. }
  966. if (k_block == K_BLOCK_MAX - 1) {
  967. pipeline.consumer_wait(smem_pipe_read, barrier_token);
  968. load_A_to_registers(smem_pipe_read.index());
  969. convert_A(0, smem_pipe_read.index());
  970. } else {
  971. convert_A(k_block + 1, read_stage);
  972. }
  973. }
  974. warpgroup_fence_operand(accum);
  975. }
  976. warpgroup_fence_operand(accum);
  977. {
  978. //
  979. // Compute on k_tile
  980. //
  981. int read_stage = smem_pipe_read.index();
  982. warpgroup_fence_operand(accum);
  983. // Unroll the K mode manually to set scale D to 1
  984. CUTLASS_PRAGMA_UNROLL
  985. for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
  986. warpgroup_arrive();
  987. // (V,M) x (V,N) => (V,M,N)
  988. cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
  989. tCrB(_, _, k_block, read_stage), accum);
  990. tiled_mma.accumulate_ = GMMA::ScaleOut::One;
  991. warpgroup_commit_batch();
  992. warpgroup_wait<K_BLOCK_MAX - 1>();
  993. if (k_block == K_BLOCK_MAX - 1) {
  994. // release prior barrier
  995. pipeline.consumer_release(
  996. smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_
  997. // on it
  998. ++smem_pipe_release;
  999. }
  1000. if (k_block < K_BLOCK_MAX - 1) {
  1001. convert_A(k_block + 1, read_stage);
  1002. }
  1003. }
  1004. }
  1005. warpgroup_fence_operand(accum);
  1006. }
  1007. // Perform a Consumer Epilogue to release all buffers
  1008. CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
  1009. PipelineState smem_pipe_release,
  1010. int k_tile_count) {
  1011. // Prologue GMMAs
  1012. int prologue_mma_count = 1;
  1013. k_tile_count -= prologue_mma_count;
  1014. smem_pipe_release.advance(k_tile_count);
  1015. // Wait on all GMMAs to complete
  1016. warpgroup_wait<0>();
  1017. for (int count = 0; count < prologue_mma_count; ++count) {
  1018. pipeline.consumer_release(
  1019. smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on
  1020. // it
  1021. ++smem_pipe_release;
  1022. }
  1023. }

 private:
  // Same as upstream; should be kept the same when possible. Left unformatted
  // to make comparison with upstream easier.
  // clang-format off
  /// Utilities for any additional inputs inside of the TMA load
  template <class... Ts>
  CUTLASS_DEVICE
  auto partition_extra_tma_inputs(
    Params const& mainloop_params,
    cute::tuple<Ts...> const& load_inputs,
    TensorStorage& shared_tensors,
    uint2 const& cluster_local_block_id,
    int const m_coord,
    int const l_coord) {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      return cute::make_tuple();
    }
    else if constexpr (ModeHasScales) {
      Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
      Tensor gS_mkl = get<2>(load_inputs);
      auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
      Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
      Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k)
      Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE)
      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return cute::make_tuple(tSgS, tSsS);
      }
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
        Tensor gZ_mkl = get<3>(load_inputs);
        auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
        Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
        Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k)
        Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE)
        return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
      }
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
      }
    }
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
    }
  }
  // clang-format on
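  // The tuple returned by partition_extra_tma_inputs above is empty for
  // DirectConvert, (tSgS, tSsS) for ConvertAndScale, and
  // (tSgS, tSsS, tZgZ, tZsZ) for ConvertAndScaleWithZero.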

  // Same as upstream; should be kept the same when possible. Left unformatted
  // to make comparison with upstream easier.
  // clang-format off
  /// Utilities for partitioning extra inputs for loading from smem in the mainloop.
  template <class ThreadMma>
  CUTLASS_DEVICE
  auto partition_extra_mma_info(
    ThreadMma const& mma_thread_slice,
    TensorStorage& shared_tensors) {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      // nothing to do
      return cute::make_tuple();
    }
    else if constexpr (ModeHasScales) {
      Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
      Tensor tCsS = mma_thread_slice.partition_A(sS);
      Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape());
      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return cute::make_tuple(tCsS, tCrS);
      }
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE)
        Tensor tCsZ = mma_thread_slice.partition_A(sZ);
        Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape());
        return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
      }
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
      }
    }
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
    }
  }
  // clang-format on
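  // Tuple layout produced by partition_extra_mma_info above (and indexed by
  // load_extra_info_to_registers / transform_A_kblock below):
  //   get<0> = tCsS, get<1> = tCrS, and when zero-points are present
  //   get<2> = tCsZ, get<3> = tCrZ.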

  // Same as upstream; should be kept the same when possible. Left unformatted
  // to make comparison with upstream easier.
  // clang-format off
  /// Returns the tiled copy and copy views for the extra inputs.
  template <class TiledMma, class... Ts>
  CUTLASS_DEVICE
  auto retile_extra_mma_info(
    TiledMma const& tiled_mma,
    cute::tuple<Ts...>& partitioned_extra_info,
    int const warp_group_thread_idx) {
    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      // nothing to do
      return cute::make_tuple();
    }
    else if constexpr (ModeHasScales) {
      auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
      auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
      Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
      if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
        return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
      }
      else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
        Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
        return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
      }
      else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
      }
    }
    else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
    }
  }
  // clang-format on
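  // Tuple layout produced by retile_extra_mma_info above:
  //   get<0> = smem_tiled_copy_S, get<1> = tCrS_copy_view, and when
  //   zero-points are present get<2> = tCrZ_copy_view.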

  // Similar to `copy_A_and_extra_info` upstream; should be kept the same when
  // possible. The main difference is that this only loads the extra info into
  // registers and not A (since we now preload more of A in the main pipeline).
  // Load scales and zeros into registers if required
  template <class... Ts, class... Us>
  CUTLASS_DEVICE void load_extra_info_to_registers(
      cute::tuple<Ts...> const& partitioned_mma_extra_info,
      cute::tuple<Us...> const& tiled_copy_and_views, int k_block,
      int read_stage) {
    if (k_block == 0) {
      // We are starting a new k-tile so copy the scale
      if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
        // nothing to do
      } else if constexpr (ModeHasScales) {
        auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
        auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
        auto tCsS = cute::get<0>(partitioned_mma_extra_info);
        copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage),
             tCrS_copy_view(_, _, k_block));
        if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
          // Nothing extra to do
        } else if constexpr (KernelConversionMode ==
                             ConversionMode::ConvertAndScaleWithZero) {
          auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
          auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
          copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage),
               tCrZ_copy_view(_, _, k_block));
        } else {
          static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                        "Conversion mode not handled in A -> RF path.");
        }
      } else {
        static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                      "Conversion mode not handled in A -> RF path.");
      }
    }
  }
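  // Note: scales/zeros for a stage are fetched from smem only once per k-tile
  // (at k_block == 0); subsequent k_blocks of the same stage reuse the
  // register copies.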

  // Similar to upstream; should be kept the same when possible.
  // The main differences are that `convert_tensor` supports interleaved
  // layouts and the bfloat16 path has been optimized. `transform_internal_A`
  // has also been inlined for code simplicity.
  // Utilities to transform A.
  template <class TCrA_load, int VectorWidthA, class TCrA_mma, class... Ts>
  CUTLASS_DEVICE void transform_A_kblock(
      TCrA_load const& tCrA_load, cute::Int<VectorWidthA> vec_A,
      TCrA_mma& tCrA_mma, cute::tuple<Ts...> const& partitioned_extra_info,
      int const k_block) {
    auto in = tCrA_load(_, _, k_block);
    auto out = tCrA_mma(_, _, k_block);

    if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
      convert_tensor<IlvdBlkLayout>(in, out, vec_A);
    } else if constexpr (ModeHasScales) {
      auto tCrS = cute::get<1>(partitioned_extra_info);
      auto converted_inputs =
          make_fragment_like<ElementScale>(tCrA_mma)(_, _, k_block);
      auto scales = tCrS(_, _, 0);

      // First, we upcast the inputs to the scale type
      convert_tensor<IlvdBlkLayout>(in, converted_inputs, vec_A);
      // Apply scales and broadcast across inputs, store in converted_inputs
      // We need to cast to nv_bfloat16 for the multiply since
      // `cutlass::bfloat16_t` has an overloaded operator* that upconverts to
      // float, which nvcc will not optimize to using vectorized fma
      // instructions (i.e. hfma.bf16_v2)
      if constexpr (std::is_same_v<ElementScale, cutlass::bfloat16_t>) {
        cute::transform(
            recast<nv_bfloat16>(converted_inputs), recast<nv_bfloat16>(scales),
            recast<nv_bfloat16>(converted_inputs), cute::multiplies{});
      } else {
        cute::transform(converted_inputs, scales, converted_inputs,
                        cute::multiplies{});
      }

      // Apply zeros if required
      if constexpr (KernelConversionMode ==
                    ConversionMode::ConvertAndScaleWithZero) {
        auto tCrZ = cute::get<3>(partitioned_extra_info);
        auto converted_zeros = make_fragment_like<ElementScale>(tCrZ)(_, _, 0);

        convert_tensor<void>(tCrZ(_, _, 0), converted_zeros);
        if constexpr (std::is_same_v<ElementScale, cutlass::bfloat16_t>) {
          cute::transform(recast<nv_bfloat16>(converted_inputs),
                          recast<nv_bfloat16>(converted_zeros),
                          recast<nv_bfloat16>(converted_inputs), cute::plus{});
        } else {
          cute::transform(converted_inputs, converted_zeros, converted_inputs,
                          cute::plus{});
        }
      }

      // Finally, we convert the scaled inputs to the mma type.
      convert_tensor<void>(converted_inputs, out);
    } else {
      static_assert(cutlass::detail::dependent_false<KernelSchedule>,
                    "No A data is loaded.");
    }
  }
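  // In the scale modes, the transform above computes per element
  //   out = convert_mma(convert_scale(in) * scale [+ zero])
  // with the multiply/add performed in the scale type (nv_bfloat16 fast path
  // when ElementScale is cutlass::bfloat16_t).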

  // Modified from upstream; should be kept the same when possible.
  // The main difference is that this version supports interleaved converts.
  // Utilities for transforming the A operand prior to issuing tensorcore math.
  template <typename IlvdBlkLayout, class EngineIn, class EngineOut,
            class TensorLayout,
            int ConversionVectorWidth = cosize_v<TensorLayout>>
  CUTLASS_DEVICE void convert_tensor(
      Tensor<EngineIn, TensorLayout> const& in,
      Tensor<EngineOut, TensorLayout>& out,
      cute::Int<ConversionVectorWidth> width = {}) {
    // This is an element-wise conversion where we expect both tensors to have
    // the same layout. As a result, we can cast as a cutlass array to use the
    // fast numeric converters without worrying about indexing into the layout.
    constexpr int N = cosize_v<TensorLayout>;

    // The inputs must be backed by registers & be statically sized.
    static_assert(is_rmem<EngineIn>::value,
                  "Input tensor for A conversion must come from registers");
    static_assert(is_rmem<EngineOut>::value,
                  "Output tensor for A conversion must come from registers");
    static_assert(is_static_v<TensorLayout>,
                  "Tensor layout for the conversion must be static");
    static_assert(cosize_v<TensorLayout> == size(TensorLayout{}),
                  "Cosize and size of the layout must be equal.");
    static_assert(
        N % ConversionVectorWidth == 0,
        "Conversion vector width must divide cosize of the tensor layout.");

    using SrcType = typename EngineIn::value_type;
    using DstType = typename EngineOut::value_type;

    using SrcArray = cutlass::Array<SrcType, ConversionVectorWidth>;
    using DstArray = cutlass::Array<DstType, ConversionVectorWidth>;

    constexpr cutlass::FloatRoundStyle RoundStyle =
        cutlass::FloatRoundStyle::round_to_nearest;

    using Converter = cutlass::InterleavedNumericArrayConverter<
        IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>;

    constexpr int NumIterations = N / ConversionVectorWidth;
    for (int ii = 0; ii < NumIterations; ++ii) {
      SrcArray const* src_array_ptr =
          reinterpret_cast<SrcArray const*>(raw_pointer_cast(in.data())) + ii;
      DstArray* dst_array_ptr =
          reinterpret_cast<DstArray*>(raw_pointer_cast(out.data())) + ii;
      *dst_array_ptr = Converter::convert(*src_array_ptr);
    }
  }
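  // For example, with cosize N = 32 and ConversionVectorWidth = 8, the loop
  // above performs four 8-wide array conversions.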
};

}  // namespace machete