
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"

namespace machete {

using namespace cute;

struct IlvBlkLayoutAuto {};
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the
// same shape as the prepacked block, all the data for a given thread is
// contiguous in memory. This allows us to use wider shared memory loads when
// loading B from shared memory. The values within a thread are also
// potentially interleaved in order to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (This is also why the other element types are
// required along with the kernel schedule.)
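//
// A rough sketch of how this template might be instantiated (illustrative
// only; the element types, layout tag, and schedule below are assumptions,
// the real ones come from the kernel dispatch code):
//
//   using PrepackedLayoutB = PrepackedLayoutBTemplate<
//       cutlass::half_t,      // ElementA
//       cutlass::uint4b_t,    // ElementB (4-bit quantized weights)
//       cutlass::half_t,      // ElementD
//       float,                // AccumulatorT
//       cutlass::layout::ColumnMajor,  // LayoutB
//       KernelTmaWarpSpecializedCooperativeMixedInput>;
//   auto layout = PrepackedLayoutB::TVbNbKL_to_offset(make_shape(N, K, L));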
// clang-format off
template <typename ElementA_, typename ElementB_, typename ElementD_,
          typename AccumulatorT, class LayoutB, class KernelSchedule,
          typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
  using MmaType = ElementA_;
  using ElementA = ElementA_;
  using ElementB = ElementB_;
  using ElementD = ElementD_;
  using ElementAccumulator =
      AccumulatorT;  // Element type for internal accumulation
  using ElementMma = MmaType;
  // Only use interleaved layouts for subbyte weights; prmt instructions make
  // non-interleaved layouts for 8bit+ weights efficient enough that we don't
  // need interleaved layouts.
  using IlvdBlkLayout = std::conditional_t<
      std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
      std::conditional_t<sizeof_bits_v<ElementB> <= 4,
                         decltype(get_interleaved_blk_layout<
                                  ElementB, sizeof_bits_v<ElementA>, 32>()),
                         void>,
      IlvBlkLayout_>;
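  // Concretely: with the default IlvBlkLayoutAuto this resolves to the layout
  // produced by get_interleaved_blk_layout for weights of 4 bits or narrower
  // (the 32 presumably being the width, in bits, of the register being
  // interleaved within), and to void (no interleaving) for 8-bit and wider
  // weights.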
  // TODO (LucasWilkinson): compare the performance for other sizes
  // Prepacked block shape: the smallest layout atom for loading into registers
  // (can contain multiple wgmma instructions worth of data in one block).
  // We ideally want this configured such that a thread can perform 128-bit
  // loads, i.e. the amount of data associated with each thread within a
  // prepacked block is a multiple of 128 bits. When using a cooperative
  // schedule we have 256 threads working on a single block at a time, so each
  // thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data;
  // for a 4-bit type this is 128 bits.
  using PPBlockShape_NK = Shape<_128, _64>;
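  // Worked example for 4-bit weights: a 128x64 block holds 8192 values, i.e.
  // 4 * 8192 = 32768 bits. Split across 256 threads that is 32768 / 256 = 128
  // bits (16 bytes) per thread, exactly one 128-bit shared memory load each.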
  // Create the shape of the tile anticipated to be used by the GEMM kernel.
  // When the kernel executes we will compute `Ct = Bt * At`, since the
  // quantized weights (B) must be the lhs operand so that they flow through
  // registers.
  // The _128 here doesn't actually impact the shape of the stored tile
  // directly, but it may impact the op selected by rs_op_selector.
  using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
                                            size<1>(PPBlockShape_NK{})));
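  // With PPBlockShape_NK = (128, 64) this evaluates to Shape<_128, _128, _64>,
  // i.e. an (M, N, K) tile for the transposed problem `Ct = Bt * At`, where
  // M is the weight matrix's N dimension.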
  static constexpr cute::GMMA::Major GmmaMajorB =
      gmma_rs_tag_to_major_B<LayoutB>();

  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions, so we use 2 atoms along the M dim (one for each warpgroup).
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelSchedule,
                      KernelTmaWarpSpecializedCooperativeMixedInput>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
  using TiledMma = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
                                 GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
      AtomLayoutMNK{}));
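  // Note: each wgmma atom is executed by one 128-thread warpgroup, so
  // size(TiledMma{}) is 256 threads for the cooperative schedule (2 atoms
  // along M) and 128 threads otherwise; this matches the 256-thread figure
  // used to size PPBlockShape_NK above.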
  // Prepacked block, (athrid, val) -> (N,K)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
    return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
  }
  // Prepacked block, (N,K) -> (athrid, val)
  // i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
  CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
    return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
  }
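  // Since ppblock_NK_to_TV() is the right inverse of ppblock_TV_to_NK(),
  // composing the two gives the identity over the (N,K) block: every (N,K)
  // coordinate maps to the (thread, value) pair that owns it in the TiledMma.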
  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
    // Return interleaved layout
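    // (`Step<_1, _0>` orders the val mode fastest, so the values owned by a
    // given thread land contiguously in storage; this is what enables the
    // wide shared memory loads described at the top of this file)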
    return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
  }
  // Prepacked block, (athrid, val) -> (storage_offset)
  // i.e. ((ThrV,(ThrN,ThrK)),(IlvdFrgV,(RestN,RestK,...))) -> (storage_idx)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
    auto layout_no_interleave =
        make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});

    if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
      return layout_no_interleave;
    } else {
      // interleave by transforming FrgV into interleaved blocks where each
      // block has the layout IlvdBlkLayout; for example, if IlvdBlkLayout is
      // (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4).
      // If FrgV is {A, B, C, D, E, F, G, H}
      // then ((IlvBlk), FrgV) is {A, C, B, D, E, G, F, H}
      auto frgV = get<1, 0>(layout_no_interleave);
      auto ilvdBlk = IlvdBlkLayout{};

      static_assert(size(frgV) % size(ilvdBlk) == 0,
                    "size(FrgV) must be divisible by size(IlvdBlkLayout)");
      auto ilvd_FrgV = make_layout(
          make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
          make_stride(stride(ilvdBlk), size(ilvdBlk)));

      // Return interleaved layout
      return make_layout(
          get<0>(layout_no_interleave),
          make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
    }
  }
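  // Net effect: the layout still maps (athrid, val) -> storage_idx, but within
  // each thread's fragment the storage order is permuted so that the values
  // the upconversion routines want adjacent (see the interleaving comment at
  // the top of this file) end up in consecutive storage slots.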
  // Prepacked block, (N,K) -> (storage_offset)
  CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
    // do (N,K) -> (athrid, val) -> (storage_idx)
    return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
  }
  // ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
      Shape_NKL shape_nkl) {
    constexpr auto block_layout = ppblock_TV_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_nkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((athrid, val), (BlocksN, BlocksK, L))
    //   => ((athrid, val), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }
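  // For example (illustrative numbers): with (N, K, L) = (256, 128, 1) and
  // PPBlockShape_NK = (128, 64), blocks_shape is (2, 2, 1). A prepacked block
  // holds 128 * 64 = 8192 values, so compact_col_major yields block strides
  // (8192, 16384, 32768): block (n, k) starts at offset 8192 * (n + 2 * k).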
  // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
  template <typename Shape_NKL>
  CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
      Shape_NKL shape_nkl) {
    constexpr auto block_layout = ppblock_ilvd_NK_to_offset();

    // (BlocksN, BlocksK, L)
    auto blocks_shape =
        cute::transform(shape_nkl, append(PPBlockShape_NK{}, _1{}),
                        [](auto x, auto y) { return x / y; });

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (storage_idx)
    auto result = make_layout(
        block_layout,
        make_layout(blocks_shape,
                    compact_col_major(blocks_shape, size(block_layout))));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L))
    //   => ((BlockN, BlockK), (BlocksN, BlocksK), L)
    return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
  }
  // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_nkl) {
    auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
                          make_layout(size<1>(PPBlockShape_NK{})));

    // ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
    auto tiled_A = zipped_divide(make_layout(shape_nkl), tile);
    return tiled_A.compose(ppblock_TV_to_NK(), _);
  }
  // (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
  template <class Shape_NKL>
  CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_nkl) {
    auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_nkl);
    return blocked_product(ppblock_NK_to_TV(),
                           make_layout(shape<1>(TVbNbK_to_NKL_layout)));
  }
};
}  // namespace machete