// scaled_mm_c2x_sm89_int8_dispatch.cuh
  1. #pragma once
  2. #include "scaled_mm_c2x.cuh"
  3. /**
  4. * This file defines Gemm kernel configurations for SM89 (int8) based on the
  5. * Gemm shape.
  6. */
  7. namespace aphrodite {
  8. template <typename InType, typename OutType,
  9. template <typename, typename> typename Epilogue>
  10. struct sm89_int8_fallback_gemm {
  11. // Shared mem requirement : 61440
  12. static_assert(std::is_same<InType, int8_t>());
  13. using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
  14. using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
  15. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  16. static int32_t const MainLoopStages = 5;
  17. using Cutlass2xGemm =
  18. cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
  19. Epilogue, TileShape, WarpShape, InstructionShape, 5>;
  20. };
  21. struct sm89_int8_config_default {
  22. // M in (256, inf)
  23. using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  24. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  25. template <typename InType, typename OutType,
  26. template <typename, typename> typename Epilogue,
  27. typename... EpilogueArgs>
  28. static void dispatch(torch::Tensor& out, torch::Tensor const& a,
  29. torch::Tensor const& b, EpilogueArgs&&... args) {
  30. static_assert(std::is_same<InType, int8_t>());
  31. TORCH_CHECK(a.dtype() == torch::kInt8);
  32. using FallbackGemm =
  33. typename sm89_int8_fallback_gemm<InType, OutType,
  34. Epilogue>::Cutlass2xGemm;
  35. uint32_t const n = out.size(1);
  36. uint32_t const np2 = next_pow_2(n);
  37. if (np2 <= 4096) {
  38. using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
  39. return aphrodite::fallback_cutlass_gemm_caller<
  40. aphrodite::cutlass_2x_gemm<
  41. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  42. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  43. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  44. } else if (np2 <= 8192) {
  45. using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
  46. return aphrodite::fallback_cutlass_gemm_caller<
  47. aphrodite::cutlass_2x_gemm<
  48. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  49. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 3>,
  50. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  51. } else if (np2 <= 16384) {
  52. using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
  53. return aphrodite::fallback_cutlass_gemm_caller<
  54. aphrodite::cutlass_2x_gemm<
  55. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  56. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  57. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  58. } else {
  59. using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
  60. return aphrodite::fallback_cutlass_gemm_caller<
  61. aphrodite::cutlass_2x_gemm<
  62. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  63. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 3>,
  64. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  65. }
  66. }
  67. };
  68. struct sm89_int8_config_M256 {
  69. // M in (128, 256]
  70. using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  71. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  72. template <typename InType, typename OutType,
  73. template <typename, typename> typename Epilogue,
  74. typename... EpilogueArgs>
  75. static void dispatch(torch::Tensor& out, torch::Tensor const& a,
  76. torch::Tensor const& b, EpilogueArgs&&... args) {
  77. static_assert(std::is_same<InType, int8_t>());
  78. TORCH_CHECK(a.dtype() == torch::kInt8);
  79. using FallbackGemm =
  80. typename sm89_int8_fallback_gemm<InType, OutType,
  81. Epilogue>::Cutlass2xGemm;
  82. uint32_t const n = out.size(1);
  83. uint32_t const np2 = next_pow_2(n);
  84. if (np2 <= 4096) {
  85. using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
  86. return aphrodite::fallback_cutlass_gemm_caller<
  87. aphrodite::cutlass_2x_gemm<
  88. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  89. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 3>,
  90. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  91. } else if (np2 <= 8192) {
  92. using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
  93. return aphrodite::fallback_cutlass_gemm_caller<
  94. aphrodite::cutlass_2x_gemm<
  95. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  96. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  97. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  98. } else if (np2 <= 16384) {
  99. using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
  100. return aphrodite::fallback_cutlass_gemm_caller<
  101. aphrodite::cutlass_2x_gemm<
  102. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  103. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 3>,
  104. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  105. } else {
  106. using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
  107. return aphrodite::fallback_cutlass_gemm_caller<
  108. aphrodite::cutlass_2x_gemm<
  109. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  110. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  111. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  112. }
  113. }
  114. };
  115. struct sm89_int8_config_M128 {
  116. // M in (64, 128]
  117. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  118. template <typename InType, typename OutType,
  119. template <typename, typename> typename Epilogue,
  120. typename... EpilogueArgs>
  121. static void dispatch(torch::Tensor& out, torch::Tensor const& a,
  122. torch::Tensor const& b, EpilogueArgs&&... args) {
  123. static_assert(std::is_same<InType, int8_t>());
  124. TORCH_CHECK(a.dtype() == torch::kInt8);
  125. using FallbackGemm =
  126. typename sm89_int8_fallback_gemm<InType, OutType,
  127. Epilogue>::Cutlass2xGemm;
  128. uint32_t const n = out.size(1);
  129. uint32_t const np2 = next_pow_2(n);
  130. if (np2 <= 8192) {
  131. using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
  132. using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  133. return aphrodite::fallback_cutlass_gemm_caller<
  134. aphrodite::cutlass_2x_gemm<
  135. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  136. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 3>,
  137. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  138. } else if (np2 <= 16384) {
  139. using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
  140. using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  141. return aphrodite::fallback_cutlass_gemm_caller<
  142. aphrodite::cutlass_2x_gemm<
  143. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  144. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  145. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  146. } else {
  147. using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
  148. using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
  149. return aphrodite::fallback_cutlass_gemm_caller<
  150. aphrodite::cutlass_2x_gemm<
  151. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  152. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  153. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  154. }
  155. }
  156. };
  157. struct sm89_int8_config_M64 {
  158. // M in (32, 64]
  159. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  160. template <typename InType, typename OutType,
  161. template <typename, typename> typename Epilogue,
  162. typename... EpilogueArgs>
  163. static void dispatch(torch::Tensor& out, torch::Tensor const& a,
  164. torch::Tensor const& b, EpilogueArgs&&... args) {
  165. static_assert(std::is_same<InType, int8_t>());
  166. TORCH_CHECK(a.dtype() == torch::kInt8);
  167. using FallbackGemm =
  168. typename sm89_int8_fallback_gemm<InType, OutType,
  169. Epilogue>::Cutlass2xGemm;
  170. uint32_t const n = out.size(1);
  171. uint32_t const np2 = next_pow_2(n);
  172. if (np2 <= 8192) {
  173. using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
  174. using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
  175. return aphrodite::fallback_cutlass_gemm_caller<
  176. aphrodite::cutlass_2x_gemm<
  177. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  178. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  179. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  180. } else {
  181. using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
  182. using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  183. return aphrodite::fallback_cutlass_gemm_caller<
  184. aphrodite::cutlass_2x_gemm<
  185. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  186. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 3>,
  187. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  188. }
  189. }
  190. };
  191. struct sm89_int8_config_M32 {
  192. // M in (16, 32]
  193. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  194. template <typename InType, typename OutType,
  195. template <typename, typename> typename Epilogue,
  196. typename... EpilogueArgs>
  197. static void dispatch(torch::Tensor& out, torch::Tensor const& a,
  198. torch::Tensor const& b, EpilogueArgs&&... args) {
  199. static_assert(std::is_same<InType, int8_t>());
  200. TORCH_CHECK(a.dtype() == torch::kInt8);
  201. using FallbackGemm =
  202. typename sm89_int8_fallback_gemm<InType, OutType,
  203. Epilogue>::Cutlass2xGemm;
  204. uint32_t const n = out.size(1);
  205. uint32_t const np2 = next_pow_2(n);
  206. if (np2 <= 8192) {
  207. using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
  208. using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
  209. return aphrodite::fallback_cutlass_gemm_caller<
  210. aphrodite::cutlass_2x_gemm<
  211. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  212. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  213. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  214. } else {
  215. using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
  216. using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
  217. return aphrodite::fallback_cutlass_gemm_caller<
  218. aphrodite::cutlass_2x_gemm<
  219. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  220. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 4>,
  221. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  222. }
  223. }
  224. };
  225. struct sm89_int8_config_M16 {
  226. // M in [1, 16]
  227. using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  228. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  229. template <typename InType, typename OutType,
  230. template <typename, typename> typename Epilogue,
  231. typename... EpilogueArgs>
  232. static void dispatch(torch::Tensor& out, torch::Tensor const& a,
  233. torch::Tensor const& b, EpilogueArgs&&... args) {
  234. static_assert(std::is_same<InType, int8_t>());
  235. TORCH_CHECK(a.dtype() == torch::kInt8);
  236. using FallbackGemm =
  237. typename sm89_int8_fallback_gemm<InType, OutType,
  238. Epilogue>::Cutlass2xGemm;
  239. uint32_t const n = out.size(1);
  240. uint32_t const np2 = next_pow_2(n);
  241. if (np2 <= 8192) {
  242. using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
  243. return aphrodite::fallback_cutlass_gemm_caller<
  244. aphrodite::cutlass_2x_gemm<
  245. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  246. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 5>,
  247. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  248. } else {
  249. using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
  250. return aphrodite::fallback_cutlass_gemm_caller<
  251. aphrodite::cutlass_2x_gemm<
  252. cutlass::arch::Sm89, aphrodite::enable_sm89_to_sm90, InType,
  253. OutType, Epilogue, TileShape, WarpShape, InstructionShape, 4>,
  254. FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
  255. }
  256. }
  257. };
  258. template <typename InType, typename OutType,
  259. template <typename, typename> typename Epilogue,
  260. typename... EpilogueArgs>
  261. inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
  262. torch::Tensor const& a,
  263. torch::Tensor const& b,
  264. EpilogueArgs&&... args) {
  265. static_assert(std::is_same<InType, int8_t>());
  266. TORCH_CHECK(a.dtype() == torch::kInt8);
  267. TORCH_CHECK(b.dtype() == torch::kInt8);
  268. uint32_t const m = a.size(0);
  269. uint32_t const mp2 =
  270. std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
  271. if (mp2 <= 16) {
  272. // M in [1, 16]
  273. return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
  274. out, a, b, std::forward<EpilogueArgs>(args)...);
  275. } else if (mp2 <= 32) {
  276. // M in (16, 32]
  277. return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
  278. out, a, b, std::forward<EpilogueArgs>(args)...);
  279. } else if (mp2 <= 64) {
  280. // M in (32, 64]
  281. return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
  282. out, a, b, std::forward<EpilogueArgs>(args)...);
  283. } else if (mp2 <= 128) {
  284. // M in (64, 128]
  285. return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
  286. out, a, b, std::forward<EpilogueArgs>(args)...);
  287. } else if (mp2 <= 256) {
  288. // M in (128, 256]
  289. return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
  290. out, a, b, std::forward<EpilogueArgs>(args)...);
  291. } else {
  292. // M in (256, inf)
  293. return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
  294. out, a, b, std::forward<EpilogueArgs>(args)...);
  295. }
  296. }
  297. } // namespace aphrodite