// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;
/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.

   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/
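/*
   In rough terms (an illustrative sketch of that contract, not a real
   epilogue), every Epilogue plugged into cutlass_3x_gemm below has the shape:

     template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
     struct SomeEpilogue {
      public:
       using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<...>;
       using ArgumentType = typename EVTCompute::Arguments;
       static ArgumentType prepare_args(torch::Tensor const& ...);
     };

   ScaledEpilogue and ScaledEpilogueBias below are the two concrete instances.
*/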
namespace {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
/*
 * This class provides the common ScaleA and ScaleB descriptors for the
 * ScaledEpilogue and ScaledEpilogueBias classes.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
      Stride<Int<1>, Int<0>, Int<0>>>;

  using ScaleBDescriptor =
      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
          EpilogueDescriptor, float>;

  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
      typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may be either int8 or fp8_e4m3 (both operands must use the same
   type). A can be quantized per-tensor or per-row. B can be quantized
   per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
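// Concretely (an illustrative note on the broadcasting): for an [M, K] x [K, N]
// problem, a_scales broadcasts as a scalar or an [M, 1] column and b_scales as
// a scalar or a [1, N] row, matching the Sm90ColOrScalarBroadcast and
// Sm90RowOrScalarBroadcast visitors defined above.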
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::ScaleA;
  using ScaleB = typename SUPER::ScaleB;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    using ScaleA_Args = typename ScaleA::Arguments;
    using ScaleB_Args = typename ScaleB::Arguments;
    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};

    return ArgumentType{a_args, {b_args}};
  }
};
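// Note: the brace nesting returned by prepare_args mirrors the Sm90EVT tree:
// the outer ArgumentType corresponds to EVTCompute (ScaleA at the root), and
// the inner {b_args} corresponds to the nested EVTCompute0 node (ScaleB
// applied to the accumulator).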
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::ScaleA;
  using ScaleB = typename SUPER::ScaleB;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using BiasDescriptor =
      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
          EpilogueDescriptor, ElementD>;

  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    using ScaleA_Args = typename ScaleA::Arguments;
    using ScaleB_Args = typename ScaleB::Arguments;
    using Bias_Args = typename Bias::Arguments;

    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
    Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};

    return ArgumentType{a_args, {b_args}, bias_args};
  }
};
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};
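// Note on layouts (as wired into CollectiveMainloop above): A is consumed as a
// row-major [M, K] operand and B as a column-major [K, N] operand, which is
// why the caller below takes lda = a.stride(0), ldb = b.stride(1), and
// ldc = out.stride(0).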
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, Int<0>{}};
  StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
  CUTLASS_CHECK(status);
}
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
  // For M > 128 and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

}  // namespace
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
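  // For example (illustrative): m = 77 rounds up to mp2 = 128 and selects the
  // (64, 128] tile configuration below, while any m > 128 rounds past 128 and
  // falls through to the default configuration.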
  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_int8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == c.dtype(),
                "currently bias dtype must match output dtype ", c.dtype());
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
        c, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
                                                           b_scales);
  }
}
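// Example call from host code (an illustrative sketch only; the shapes and the
// way a, b, a_scales, and b_scales are produced are assumptions, not part of
// this file):
//
//   // a: [M, K] int8/fp8_e4m3, row-major; b: [K, N] int8/fp8_e4m3, column-major
//   torch::Tensor c = torch::empty(
//       {M, N}, torch::dtype(torch::kBFloat16).device(a.device()));
//   // a_scales: 1 or M fp32 values; b_scales: 1 or N fp32 values
//   cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, c10::nullopt);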

#endif