scaled_mm_c2x.cu

#include <stddef.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"

#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"

/*
   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/
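/*
   In broad strokes, each entry point below computes

     out = epilogue(a_scales * b_scales * (a @ b))

   where the epilogue (defined in scaled_mm_c2x.cuh) optionally fuses a bias
   add and, for the *_azp_* variants, an asymmetric zero-point correction
   applied to the accumulator before scaling. The exact scale and azp
   broadcasting rules live in the epilogue definitions, not in this file.
*/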
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return aphrodite::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
                                                 Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return aphrodite::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t,
                                                 Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}
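// SM75 (Turing) entry point: int8 x int8 GEMM with fp32 scales, fp16/bf16
// output, and an optional bias fused into the epilogue.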
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm75_epilogue<aphrodite::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<aphrodite::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}
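// SM75 entry point with asymmetric zero-point (azp) support: azp_adj carries
// the zero-point adjustment term computed by the caller, and an optional
// per-token azp tensor selects the ScaledEpilogueBiasAzpToken variant.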
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm75_epilogue<
        aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
                                               azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
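// SM80 (Ampere) path: same int8-only dispatch structure as SM75, routed
// through the sm80 kernel configurations.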
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return aphrodite::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
                                                 Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return aphrodite::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t,
                                                 Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm80_epilogue<aphrodite::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<aphrodite::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm80_epilogue<
        aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
                                               azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
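// SM89 (Ada) path: besides int8, this architecture also takes fp8 (e4m3)
// inputs, so the epilogue helper dispatches on the input dtype as well as
// the output dtype.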
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return aphrodite::cutlass_gemm_sm89_int8_dispatch<
          int8_t, cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return aphrodite::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
                                                        Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return aphrodite::cutlass_gemm_sm89_fp8_dispatch<
          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return aphrodite::cutlass_gemm_sm89_fp8_dispatch<
          cutlass::float_e4m3_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm89_epilogue<aphrodite::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<aphrodite::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm89_epilogue<
        aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
                                               azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
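/*
   Illustrative call (a sketch, not compiled as part of this file): given int8
   tensors a and b, fp32 scale tensors a_scales and b_scales, and a fp16/bf16
   output tensor out, the Ampere path would be invoked as

     cutlass_scaled_mm_sm80(out, a, b, a_scales, b_scales, c10::nullopt);

   Passing a bias tensor of out's dtype in place of c10::nullopt fuses the
   bias add into the epilogue.
*/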