// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.

   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/
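
// As a rough sketch of that contract (not an epilogue used below;
// SomeCompute, SomeLoad, Accum, and the tensor argument are placeholders):
//
//   struct ExampleEpilogue {
//     using EVTCompute =
//         cutlass::epilogue::fusion::Sm90EVT<SomeCompute, SomeLoad, Accum>;
//     using ArgumentType = typename EVTCompute::Arguments;
//     static ArgumentType prepare_args(torch::Tensor const& some_tensor);
//   };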
namespace {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  template <typename T>
  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<0>, Int<1>, Int<0>>>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
                    !std::is_same_v<Descriptor, RowLoad<T, true>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
                  std::is_same_v<Descriptor, RowLoad<T, true>>);
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may both be either int8 or fp8_e4m3. A can be
   quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) @ (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per-row/column.
*/
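
// To make the broadcasting concrete: with per-row a_scales of shape (m, 1) and
// per-column b_scales of shape (1, n), the fused epilogue below is equivalent
// to the following reference (an illustrative torch-like sketch, not code that
// is compiled here):
//
//   accum = A @ B                                       // int32 or fp32 accumulator
//   D = (a_scales * (b_scales * accum)).to(out_dtype)   // broadcasted scales
//
// i.e. D[i, j] = a_scales[i] * b_scales[j] * accum[i, j]. A scale tensor with a
// single element is treated as a scalar by the *OrScalarLoad descriptors above.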
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
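
// With the bias added, the element-wise result computed by this epilogue is
// (illustrative formula; bias is broadcast along rows):
//
//   D[i, j] = a_scales[i] * b_scales[j] * accum[i, j] + bias[j]
//
// which is why the final node (Compute1) uses multiply_add rather than
// multiplies.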
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
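
// The correction follows from expanding the asymmetric quantization of A
// (a sketch; J here denotes a row vector of k ones, broadcast over the m rows):
//
//   A_q = A / a_scale + azp                       // per-tensor zero point
//   A_q @ B = (A / a_scale) @ B + azp * (J @ B)
//
// so subtracting the precomputed azp_adj = azp * (J @ B), a (1,n) row vector,
// from the int32 accumulator (ComputeAzp below) removes the zero-point
// contribution before the scales and bias are applied.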
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
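
// Put differently, the (m,n) correction matrix is never materialized; it is
// the outer product of the two stored vectors (an illustrative formula, with
// the optional bias shown):
//
//   D[i, j] = a_scales[i] * b_scales[j] * (accum[i, j] - azp[i] * azp_adj[j])
//             + bias[j]
//
// The per-element product azp[i] * azp_adj[j] is evaluated inside the fused
// epilogue (EVTComputeAzp below) rather than written to global memory.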
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, 0};
  StrideB b_stride{ldb, Int<1>{}, 0};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
  // For M > 128 and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

}  // namespace
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_int8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}

void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == c.dtype(),
                "currently bias dtype must match output dtype ", c.dtype());
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
        c, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
                                                           b_scales);
  }
}

void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

#endif