scaled_mm_c3x.cu

// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "broadcast_load_epilogue_c3x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.

   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace {

uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
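// For illustration: next_pow_2(1) == 1, next_pow_2(33) == 64, and
// next_pow_2(64) == 64. The dispatch functions below use this to bucket the
// runtime M dimension into the tile configurations defined in this file.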

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B can both be either int8 or fp8_e4m3. A can be quantized per-tensor
   or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or per-column quantization is
   supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per-row or per-column.
*/
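// Concretely, for per-row a_scales of shape (m, 1) and per-column b_scales of
// shape (1, n), each output element is
//   D[i, j] = a_scales[i] * b_scales[j] * sum_k A[i, k] * B[k, j].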
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue {
 private:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
      Stride<Int<1>, Int<0>, Int<0>>>;

  using ScaleBDescriptor =
      cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
          EpilogueDescriptor, float>;

  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
      typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    using ScaleA_Args = typename ScaleA::Arguments;
    using ScaleB_Args = typename ScaleB::Arguments;
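    // The second field of each Arguments struct enables vector broadcast:
    // when a scale tensor holds more than one element it is read as a
    // per-row / per-column vector, otherwise its single value is applied as
    // a scalar.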
    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};

    return ArgumentType{a_args, {b_args}};
  }
};

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};
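// cutlass_3x_gemm is instantiated by the sm90_*_config structs further below;
// the Epilogue_ parameter is expected to follow the contract described at the
// top of this file (a public EVTCompute type plus a static prepare_args).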

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, Int<0>{}};
  StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args,
                                      epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
  CUTLASS_CHECK(status);
}
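
// Sketch of a call, assuming a row-major (m x k) tensor a, a column-major
// (k x n) tensor b, a caller-allocated out, and a hypothetical Cutlass3xGemm
// instantiation of cutlass_3x_gemm; the trailing arguments are forwarded to
// the epilogue's prepare_args:
//   cutlass_gemm_caller<Cutlass3xGemm>(out, a, b, a_scales, b_scales);

// The sm90_*_config structs below pair cutlass_3x_gemm with tile/cluster
// shapes and schedules tuned for particular ranges of M.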

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
  // For M > 128 and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

}  // namespace

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
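
  // e.g. m == 100 rounds up to mp2 == 128 and selects Cutlass3xGemmM128,
  // while m == 1000 gives mp2 == 1024 and falls through to the default.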
  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_int8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             ScaledEpilogue>(
          out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t,
                                             ScaledEpilogue>(
          out, a, b, a_scales, b_scales);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t,
                                            ScaledEpilogue>(
          out, a, b, a_scales, b_scales);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, ScaledEpilogue>(
          out, a, b, a_scales, b_scales);
    }
  }
}
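
// Hypothetical host-side call (tensor names and shapes assumed for
// illustration): a is an (m, k) int8/fp8 tensor, b is (k, n) in column-major
// layout, and both scale tensors are float32.
//   auto out = torch::empty(
//       {m, n}, torch::dtype(torch::kBFloat16).device(a.device()));
//   cutlass_scaled_mm_sm90(out, a, b, a_scales, b_scales);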

#endif