// machete_mm_kernel.cuh
  1. #pragma once
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include <torch/all.h>
  5. // clang-format off
  6. // The cutlass include order matters (annoyingly)
  7. #include "cutlass/cutlass.h"
  8. #include "cute/tensor.hpp"
  9. #include "cutlass/tensor_ref.h"
  10. #include "cutlass/epilogue/collective/default_epilogue.hpp"
  11. #include "cutlass/epilogue/thread/linear_combination.h"
  12. #include "cutlass/gemm/dispatch_policy.hpp"
  13. #include "cutlass/gemm/collective/collective_builder.hpp"
  14. #include "cutlass/epilogue/collective/collective_builder.hpp"
  15. #include "cutlass/gemm/device/gemm_universal_adapter.h"
  16. #include "cutlass/gemm/kernel/gemm_universal.hpp"
  17. // clang-format on
  18. #include "cutlass_extensions/cute_utils.cuh"
  19. #include "cutlass_extensions/aphrodite_numeric_conversion.cuh"
  20. #include "machete_collective_builder.cuh"
  21. #include "machete_prepacked_layout.cuh"
  22. #include "machete_interleaving_utils.cuh"
  23. namespace machete {
  24. using namespace cute;
  25. // NOTE This kernel computes D = alpha * A * B + beta * C by computing
  26. // D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
  27. // instructions only support sourcing from registers for the left-hand
  28. // operand, we want to upconvert/decompress the quantized operand in
  29. // register. Since the primary use case we want to support is Y = XW^t where
  30. // W is quantized, in this situation or right-hand operand is quantized so
  31. // we compute the transpose to move it to the left-hand side.
  32. template <typename ElementA_, typename ElementB_, typename ElementD_,
  33. typename AccumulatorT, typename ScaleT, typename ZeroT,
  34. class KernelSchedule, typename ScheduleConfig, bool with_C,
  35. bool with_scales, bool with_zeropoints>
  36. struct MacheteKernelTemplate {
  37. using MmaType = ElementA_;
  38. using ElementA = ElementA_;
  39. using ElementB = ElementB_;
  40. using ElementD = ElementD_;
  41. using ElementC = cute::conditional_t<with_C, ElementD, void>;
  42. using ElementZ = ZeroT;
  43. using ElementS = ScaleT;
  44. using ElementAccumulator =
  45. AccumulatorT; // Element type for internal accumulation
  46. using ElementCompute = AccumulatorT; // For Epilogue
  47. using BTypeTuple = cute::conditional_t<
  48. with_scales,
  49. cute::conditional_t<with_zeropoints,
  50. cute::tuple<ElementB, ElementS, ElementZ>,
  51. cute::tuple<ElementB, ElementS>>,
  52. ElementB>;
  53. using LayoutA = cutlass::layout::RowMajor;
  54. using LayoutC = cutlass::layout::RowMajor;
  55. using LayoutD = LayoutC;
  56. using LayoutScale = cutlass::layout::RowMajor;
  57. // not actually used since B has the prepacked layout, but required by cutlass
  58. using _LayoutB = cutlass::layout::ColumnMajor;
  59. // Interface strides expected by create_arguments (will get transposed)
  60. using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
  61. using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
  62. using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
  63. using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
  64. using StrideZ = StrideS;
  65. using LayoutA_Transpose =
  66. typename cutlass::layout::LayoutTranspose<LayoutA>::type;
  67. using LayoutC_Transpose =
  68. typename cutlass::layout::LayoutTranspose<LayoutC>::type;
  69. using LayoutD_Transpose =
  70. typename cutlass::layout::LayoutTranspose<LayoutD>::type;
  71. using ArchTag = cutlass::arch::Sm90;
  72. using OperatorClass = cutlass::arch::OpClassTensorOp;
  73. using PrepackedLayoutB =
  74. PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
  75. LayoutA_Transpose, KernelSchedule>;
  76. static int constexpr TileShapeK =
  77. 128 * 8 / cutlass::sizeof_bits<MmaType>::value;
  78. static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
  79. static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
  80. static int constexpr AlignmentC =
  81. (with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
  82. static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
  83. using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
  84. cute::Int<TileShapeK>{}));
  85. using ClusterShape = typename ScheduleConfig::ClusterShape;
  86. using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
  87. using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
  88. using TileScheduler = typename ScheduleConfig::TileScheduler;
  89. using CollectiveEpilogue =
  90. typename cutlass::epilogue::collective::CollectiveBuilder<
  91. ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
  92. ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
  93. AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
  94. EpilogueSchedule>::CollectiveOp;
  95. using CollectiveMainloop =
  96. typename cutlass::gemm::collective::APHRODITECollectiveBuilder<
  97. cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
  98. BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
  99. AlignmentA, ElementAccumulator, TileShape, ClusterShape,
  100. cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
  101. sizeof(typename CollectiveEpilogue::SharedStorage))>,
  102. KernelSchedule>::CollectiveOp;
  103. using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
  104. Shape<int, int, int, int>, // Indicates ProblemShape
  105. CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
  106. using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  107. // stride_B is unused (since B is prepacked), but still required by cutlass
  108. using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
  109. using Arguments = typename Gemm::Arguments;
  110. using MainloopArguments = typename GemmKernel::MainloopArguments;
  111. using EpilogueArguments = typename GemmKernel::EpilogueArguments;
  112. template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
  113. typename ShapeZ>
  114. static Arguments create_arguments(
  115. cudaStream_t stream,
  116. ElementA const* A_ptr, // A is an MxK matrix
  117. Layout<ShapeA, StrideA> const& layout_A,
  118. ElementB const* B_ptr, // B is an KxN prepacked matrix
  119. ElementD* D_ptr, // D is an MxN matrix
  120. Layout<ShapeD, StrideD> const& layout_D,
  121. ElementC const* C_ptr, // C is an MxN matrix
  122. std::optional<Layout<ShapeC, StrideC>> const& layout_C,
  123. ElementS const* S_ptr, // S is an scale_KxN matrix
  124. std::optional<Layout<ShapeS, StrideS>> const& layout_S,
  125. ElementZ const* Z_ptr, // Z is an scale_KxN matrix
  126. std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
  127. ElementCompute alpha, ElementCompute beta,
  128. std::optional<int> maybe_group_size) {
  129. static_assert(!with_zeropoints || with_scales);
  130. int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
  131. int const group_size =
  132. maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
  133. int const scale_k = (K + group_size - 1) / group_size;
  134. TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
  135. TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
  136. if constexpr (with_C) {
  137. TORCH_CHECK(C_ptr && layout_C);
  138. } else {
  139. TORCH_CHECK(!C_ptr, "C not supported");
  140. }
  141. if constexpr (with_scales) {
  142. TORCH_CHECK(S_ptr && layout_S);
  143. TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
  144. } else {
  145. TORCH_CHECK(!S_ptr, "Scales not supported");
  146. }
  147. if constexpr (with_zeropoints) {
  148. TORCH_CHECK(Z_ptr && layout_Z);
  149. TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
  150. TORCH_CHECK(layout_S && *layout_Z == *layout_S,
  151. "Scales and zeros must have the same layout");
  152. } else {
  153. TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
  154. }
  155. // Transpose A and D
  156. // A doesn't need to be transposed since cutlass expects a NxK matrix
  157. // for B (which is At)
  158. auto stride_At = layout_A.stride();
  159. auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
  160. auto stride_Ct = stride_Dt;
  161. if (layout_C) {
  162. stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
  163. }
  164. MainloopArguments mainloop_arguments{};
  165. EpilogueArguments epilogue_arguments{
  166. {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
  167. if constexpr (with_scales && with_zeropoints) {
  168. auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
  169. mainloop_arguments =
  170. MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
  171. S_ptr, stride_S, group_size, Z_ptr};
  172. } else if constexpr (with_scales) {
  173. auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
  174. mainloop_arguments = MainloopArguments{
  175. B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
  176. } else {
  177. mainloop_arguments =
  178. MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
  179. }
  180. return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
  181. {N, M, K, 1},
  182. mainloop_arguments,
  183. epilogue_arguments};
  184. };
  185. static size_t get_workspace_size(Arguments const& args) {
  186. return Gemm::get_workspace_size(args);
  187. }
  188. static bool can_implement(Arguments const& args) {
  189. return Gemm::can_implement(args) == cutlass::Status::kSuccess;
  190. }
  191. static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
  192. Gemm gemm_op;
  193. cutlass::Status status = gemm_op.initialize(args, workspace, stream);
  194. TORCH_CHECK(status == cutlass::Status::kSuccess,
  195. "Machete kernel failed to initialize workspace");
  196. status = gemm_op.run(stream);
  197. TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
  198. }
  199. };
  200. }; // namespace machete