// scaled_mm_c2x_sm89_fp8_dispatch.cuh

#pragma once

#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"

/**
 * This file defines Gemm kernel configurations for SM89 (FP8) based on the
 * Gemm shape.
 */

namespace aphrodite {

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm89_fp8_fallback_gemm {
  // Shared memory required by this Gemm: 61440 bytes
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());

  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType,
                      OutType, Epilogue, TileShape, WarpShape,
                      InstructionShape, 5, FP8MathOperator>;
};
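
// Note on the fallback: fallback_cutlass_gemm_caller (from scaled_mm_c2x.cuh)
// is expected to fall back to the Gemm above when the primary configuration
// selected below needs more shared memory than the device makes available.
// The 61440-byte figure quoted above is consistent with the usual CUTLASS
// multistage staging-buffer accounting (a sketch of that accounting, not a
// quote from the kernel):
//
//   (tile_M * tile_K + tile_N * tile_K) * sizeof(fp8) * stages
//     = (64 * 64 + 128 * 64) * 1 * 5
//     = 61440 bytes
//
// Deeper pipelines (more stages) multiply this cost, which is why the larger
// tiles below back off to 3 or 4 stages.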
struct sm89_fp8_config_default {
  // M in (256, inf)
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
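
// All the configs in this file bucket on next_pow_2(n), where n = out.size(1).
// next_pow_2 comes from the included headers; a minimal sketch of the behavior
// this dispatch relies on (illustrative only, not the shipped implementation):
//
//   static inline uint32_t next_pow_2(uint32_t x) {
//     return x <= 1 ? x : 1u << (32 - __builtin_clz(x - 1));
//   }
//
// Worked example for sm89_fp8_config_default above: n = 5120 rounds up to
// np2 = 8192, which selects the 256x128x64 tile with a 3-stage mainloop.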
struct sm89_fp8_config_M256 {
  // M in (128, 256]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
struct sm89_fp8_config_M128 {
  // M in (64, 128]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
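
// The stage counts in sm89_fp8_config_M128 illustrate the tradeoff sketched in
// the note after sm89_fp8_fallback_gemm: with the same staging-buffer
// accounting, the 128x128x64 tile at 5 stages needs roughly
// (128 * 64 + 128 * 64) * 5 = 81920 bytes of FP8 operand storage, while the
// deeper-K 128x64x128 tile backs off to 3 stages to stay near
// (128 * 128 + 64 * 128) * 3 = 73728 bytes.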
struct sm89_fp8_config_M64 {
  // M in (32, 64]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
struct sm89_fp8_config_M32 {
  // M in (16, 32]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 4, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<cutlass::arch::Sm89,
                                     aphrodite::enable_sm89_to_sm90, InType,
                                     OutType, Epilogue, TileShape, WarpShape,
                                     InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
struct sm89_fp8_config_M16 {
  // M in [1, 16]
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  static const int32_t MainLoopStages = 5;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<
              cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
              OutType, Epilogue, TileShape, WarpShape, InstructionShape,
              MainLoopStages, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 24576) {
      using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<
              cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
              OutType, Epilogue, TileShape, WarpShape, InstructionShape,
              MainLoopStages, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;

      return aphrodite::fallback_cutlass_gemm_caller<
          aphrodite::cutlass_2x_gemm<
              cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
              OutType, Epilogue, TileShape, WarpShape, InstructionShape,
              MainLoopStages, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};
template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2

  if (mp2 <= 16) {
    // M in [1, 16]
    return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace aphrodite
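
// Usage sketch (illustrative only, not part of this header): a caller such as
// a cutlass_scaled_mm entry point would pick the output type and an epilogue
// and forward the epilogue arguments. `ScaledEpilogue` and the scale tensors
// below are stand-ins for whatever the surrounding code actually defines:
//
//   void example_sm89_fp8_mm(torch::Tensor& out, torch::Tensor const& a,
//                            torch::Tensor const& b,
//                            torch::Tensor const& a_scales,
//                            torch::Tensor const& b_scales) {
//     if (out.dtype() == torch::kBFloat16) {
//       aphrodite::cutlass_gemm_sm89_fp8_dispatch<
//           cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
//           out, a, b, a_scales, b_scales);
//     } else {
//       aphrodite::cutlass_gemm_sm89_fp8_dispatch<
//           cutlass::float_e4m3_t, cutlass::half_t, ScaledEpilogue>(
//           out, a, b, a_scales, b_scales);
//     }
//   }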