scaled_mm_c2x.cu
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).

   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/
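
/*
   Illustrative skeleton only (the names here are placeholders, not part of
   this file): an epilogue that satisfies the contract above provides the two
   members described, e.g.

     template <typename ElementD, typename OutputTileThreadMap>
     struct MyEpilogue {
       using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<...>;
       using ArgumentType = typename EVTCompute::Arguments;
       static ArgumentType prepare_args(...);
     };

   ScaledEpilogue and ScaledEpilogueBias below are the concrete examples.
*/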
namespace {

// Wrappers for the GEMM kernel that are used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the
// ifdef into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};
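
// Illustrative use of the guards above: wrapping a kernel type so that its
// device entry point compiles to a no-op outside the intended SM range, e.g.
//
//   using GuardedKernel = enable_sm80_to_sm89<SomeGemmKernel>;
//
// `SomeGemmKernel` is a placeholder; in this file the guard is applied to the
// kernel produced by DefaultGemmWithVisitor (see cutlass_2x_gemm below).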
/*
 * This class provides the common ScaleA and ScaleB descriptors for the
 * ScaledEpilogue and ScaledEpilogueBias classes.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;

  using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
      OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) @ (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per-row or per-column.
*/
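
/*
   Worked shape example (illustrative): for A of shape (M, K) and B of shape
   (K, N), a per-row a_scales has shape (M, 1) and a per-column b_scales has
   shape (1, N); per-tensor scales are single scalars. After broadcasting, the
   epilogue produces

     D[m, n] = a_scales[m] * b_scales[n] * (A @ B)[m, n]

   with scalar scales simply broadcast over the whole corresponding dimension.
*/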
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::ScaleA;
  using ScaleB = typename SUPER::ScaleB;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    using ScaleAArgs = typename ScaleA::Arguments;
    using ScaleBArgs = typename ScaleB::Arguments;

    // numel() != 1 marks the scale as a per-row/per-column vector; a
    // single-element tensor is treated as a per-tensor scalar scale.
    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};

    typename EVTCompute0::Arguments evt0_compute_args{b_args};

    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
    return evt_compute_args;
  }
};
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::ScaleA;
  using ScaleB = typename SUPER::ScaleB;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  // multiply_add fuses the bias: D = a_scales * (b_scales * Accum) + bias
  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    using ScaleAArgs = typename ScaleA::Arguments;
    using ScaleBArgs = typename ScaleB::Arguments;
    using BiasArgs = typename Bias::Arguments;

    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
    BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};

    typename EVTCompute0::Arguments evt0_compute_args{b_args};

    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
                                                    bias_args};
    return evt_compute_args;
  }
};
template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                cutlass::arch::OpMultiplyAdd>::type;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;

  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
      float, cutlass::layout::RowMajor, 4,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
      MainLoopStages, Operator,
      1 /* epilogue stages */
      >::GemmKernel>;
  // clang-format on

  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
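
// Illustrative only: a concrete instantiation matching the sm80 defaults used
// further below would look like
//
//   using ExampleGemm = cutlass_2x_gemm<
//       cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
//       ScaledEpilogue, cutlass::gemm::GemmShape<128, 128, 64>,
//       cutlass::gemm::GemmShape<64, 64, 64>,
//       cutlass::gemm::GemmShape<16, 8, 32>, 5>;
//
// `ExampleGemm` is a hypothetical name; the real configurations are assembled
// by the sm80_config_* structs below.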
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  // A is row-major and B is column-major (see KernelType above), so the
  // leading dimensions are a.stride(0) and b.stride(1) respectively.
  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };

  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.get(), stream);
  CUTLASS_CHECK(status);
}
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                  torch::Tensor const& b,
                                  EpilogueArgs&&... args) {
  // In some cases, the GPU isn't able to accommodate the
  // shared memory requirements of the Gemm. In such cases, use
  // the FallbackGemm instead.
  static const int max_shared_mem_per_block_opt_in =
      get_cuda_max_shared_memory_per_block_opt_in(0);

  size_t const gemm_shared_mem_size =
      sizeof(typename Gemm::KernelType::SharedStorage);
  size_t const fallback_gemm_shared_mem_size =
      sizeof(typename FallbackGemm::KernelType::SharedStorage);

  if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
    return cutlass_gemm_caller<Gemm>(out, a, b,
                                     std::forward<EpilogueArgs>(args)...);
  } else {
    TORCH_CHECK(fallback_gemm_shared_mem_size <=
                max_shared_mem_per_block_opt_in);
    return cutlass_gemm_caller<FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
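
// For example, if the selected Gemm requires 122880 bytes of shared memory
// (see sm80_config_M64 below) but the device's opt-in limit is lower, the
// call falls back to FallbackGemm, whose SharedStorage must fit; the dispatch
// code below uses sm80_config_M32 (61440 bytes) as that fallback.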
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_default {
  // This config is used in 2 cases,
  // - M in (128, inf)
  // - M in (64, 128] and N >= 8192
  // Shared Memory required by this Gemm - 81920 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
  // This config is used in 2 cases,
  // - M in (32, 64]
  // - M in (64, 128] and N < 8192
  // Shared Memory required by this Gemm - 122880 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
  // M in (16, 32]
  // Shared Memory required by this Gemm - 61440 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
  // M in [1, 16]
  // Shared Memory required by this Gemm - 51200 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

}  // namespace
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass2xGemmDefault =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128BigN =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128SmallN =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM64 =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM32 =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM16 =
      typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;

  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  // in such cases.
  // sm80_config_M16 has the least shared-memory requirement. However,
  // based on some profiling, we select sm80_config_M32 as a better alternative
  // performance wise.
  using FallbackGemm =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
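
  // Example dispatch decisions implied by the thresholds below: m = 1..16
  // selects Cutlass2xGemmM16; m = 24 (mp2 = 32) selects Cutlass2xGemmM32;
  // m = 100 (mp2 = 128) selects Cutlass2xGemmM128SmallN when n < 8192 and
  // Cutlass2xGemmM128BigN otherwise; m > 128 selects Cutlass2xGemmDefault.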
  if (mp2 <= 16) {
    // M in [1, 16]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    uint32_t const n = out.size(1);
    bool const small_n = n < 8192;
    if (small_n) {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
                                          FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else {
    // M in (128, inf)
    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_caller<cutlass_2x_gemm<
        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_caller<cutlass_2x_gemm<
        cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
                                                           b_scales);
  }
}
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
                                                           b_scales);
  }
}
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_caller<cutlass_2x_gemm<
          cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_caller<
          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
                          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
                          TileShape, WarpShape, InstructionShape, 5>>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_caller<
          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
                          cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
                          TileShape, WarpShape, InstructionShape, 5>>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
                                                           b_scales);
  }
}
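
/*
   A minimal usage sketch (not part of this file), assuming an int8-quantized
   A (M x K, row-major) and B (K x N, column-major) with fp32 scales and an
   fp16 output tensor; the tensor names and shapes here are illustrative:

     torch::Tensor out = torch::empty(
         {M, N}, torch::dtype(torch::kFloat16).device(torch::kCUDA));
     cutlass_scaled_mm_sm80(out, a_int8, b_int8, a_scales_fp32, b_scales_fp32,
                            c10::nullopt);

   In practice these entry points are typically reached through whatever
   op-registration layer the surrounding project provides rather than called
   directly.
*/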