// scaled_mm_c2x.cuh
#pragma once

#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;

/*
   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace aphrodite {

// Wrappers for the GEMM kernel that guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};
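
// Illustrative usage sketch (not part of this header's API): wrapping a
// hypothetical CUTLASS kernel type `SomeGemmKernel` so that its device code
// is only emitted for SM80-SM89; on other architectures invoke() compiles
// to a no-op:
//
//   using GuardedKernel = enable_sm80_to_sm89<SomeGemmKernel>;
//   // GuardedKernel is the type that gets wrapped in GemmUniversalAdapter
//   // further below.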

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // it would technically work but no use case as data_ptr is never nullptr
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};
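
// Illustrative note on args_from_tensor (behavior summarized from the code
// above): for a ColOrScalarLoad / RowOrScalarLoad descriptor, a scales tensor
// with numel() != 1 yields Arguments{ptr, true} (per-row or per-column
// broadcast), while a single-element tensor yields Arguments{ptr, false}
// (scalar broadcast). The RowOrZeroLoad overload accepts an empty
// c10::optional, in which case a nullptr is passed and the visitor
// broadcasts a constant zero.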

/*
 This epilogue function defines a quantized GEMM operation similar to
 torch._scaled_mm.

 A and B may both be either int8 or fp8_e4m3. A can be quantized per-tensor or
 per-row. B can be quantized per-tensor or per-column.
 Any combination of per-tensor and per-row or column is supported.
 A and B must have symmetric quantization (zero point == 0).

 So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
 scales are applied elementwise with numpy-style broadcasting.

 ScaleA and ScaleB define the epilogue functions that apply the scales for
 the A and B operands respectively. These scales may be either per-tensor or
 per-row/per-column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};
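
/*
 * Worked example of the broadcasting above (illustrative only): with per-row
 * a_scales of shape (m, 1) and per-column b_scales of shape (1, n), the
 * visitor tree evaluates, for each output element,
 *
 *   D[i, j] = ElementD(a_scales[i] * (b_scales[j] * Accum[i, j]))
 *
 * where Accum = A @ B is the int32 (int8 inputs) or float (fp8 inputs)
 * accumulator. Per-tensor scales are the (1, 1) case, selected at runtime by
 * the `tensor.numel() != 1` flag that args_from_tensor passes to the
 * Col/RowOrScalarLoad descriptors.
 */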

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};
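
/*
 * Illustrative formula for the epilogue above: with a per-channel bias of
 * shape (1, n),
 *
 *   D[i, j] = ElementD(a_scales[i] * (b_scales[j] * Accum[i, j]) + bias[j])
 *
 * i.e. Compute1 is a fused multiply-add over ScaleA, EVTCompute0, and Bias.
 */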

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
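
/*
 * Illustrative formula for the per-tensor azp epilogue above: with
 * azp_adj[j] = azp * sum_k B[k, j] (i.e. azp * (J @ B)),
 *
 *   D[i, j] = ElementD(a_scales[i] * (b_scales[j] *
 *                 float(Accum[i, j] - azp_adj[j])) + bias[j])
 *
 * The subtraction removes the contribution of the activation zero point from
 * the int32 accumulator before the scales and bias are applied.
 */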

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};
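
/*
 * Illustrative formula for the per-token azp epilogue above: with azp[i] the
 * zero point of row i of A and azp_adj[j] = sum_k B[k, j],
 *
 *   D[i, j] = ElementD(a_scales[i] * (b_scales[j] *
 *                 float(Accum[i, j] - azp[i] * azp_adj[j])) + bias[j])
 *
 * The azp[i] * azp_adj[j] product is the rank-1 correction term mentioned in
 * the comment above; it is computed on the fly rather than materialized as an
 * (m, n) matrix.
 */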

template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;
  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                FP8MathOperator>::type;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;

  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
      float, cutlass::layout::RowMajor, 4,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
      MainLoopStages, Operator,
      1 /* epilogue stages */
      >::GemmKernel>;
  // clang-format on

  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
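
// Illustrative instantiation (a sketch only; the tile, warp, and instruction
// shapes below are assumptions chosen for the example, not necessarily the
// configurations that this project's dispatch code uses):
//
//   using ExampleGemm = cutlass_2x_gemm<
//       cutlass::arch::Sm80, enable_sm80_to_sm89,   // target arch + guard
//       int8_t, cutlass::bfloat16_t,                // ElementAB, ElementD
//       ScaledEpilogue,                             // epilogue functor
//       cutlass::gemm::GemmShape<128, 128, 64>,     // threadblock tile
//       cutlass::gemm::GemmShape<64, 64, 64>,       // warp tile
//       cutlass::gemm::GemmShape<16, 8, 32>,        // MMA instruction shape
//       5>;                                         // mainloop stages
//
//   // ExampleGemm::Op is the GemmUniversalAdapter that cutlass_gemm_caller
//   // below instantiates and launches.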

template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };

  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}
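
// Illustrative call (a sketch, assuming the ExampleGemm alias from the comment
// above, a row-major `a`, a column-major `b`, and a row-major `out`):
//
//   cutlass_gemm_caller<ExampleGemm>(out, a, b, a_scales, b_scales);
//
// The trailing arguments are forwarded verbatim to
// ExampleGemm::Epilogue::prepare_args, so they must match the chosen
// epilogue's signature (e.g. ScaledEpilogueBias also expects a bias tensor).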

template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... args) {
  // In some cases, the GPU isn't able to accommodate the
  // shared memory requirements of the Gemm. In such cases, use
  // the FallbackGemm instead.
  static const int max_shared_mem_per_block_opt_in =
      get_cuda_max_shared_memory_per_block_opt_in(0);

  size_t const gemm_shared_mem_size =
      sizeof(typename Gemm::KernelType::SharedStorage);
  size_t const fallback_gemm_shared_mem_size =
      sizeof(typename FallbackGemm::KernelType::SharedStorage);

  if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
    return cutlass_gemm_caller<Gemm>(out, a, b,
                                     std::forward<EpilogueArgs>(args)...);
  } else {
    TORCH_CHECK(fallback_gemm_shared_mem_size <=
                max_shared_mem_per_block_opt_in);
    return cutlass_gemm_caller<FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}
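
// Illustrative call (a sketch; ExampleFallbackGemm stands in for a second,
// hypothetical cutlass_2x_gemm instantiation with a smaller tile and/or fewer
// mainloop stages, so that its SharedStorage fits on more GPUs):
//
//   fallback_cutlass_gemm_caller<ExampleGemm, ExampleFallbackGemm>(
//       out, a, b, a_scales, b_scales);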

}  // namespace aphrodite