// scaled_mm_entry.cu

#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

// Forward declarations of the per-architecture kernel entry points
// (defined in separate translation units).
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);
#endif
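
// The _azp variants additionally apply zero-point corrections for
// asymmetrically quantized activations: azp_adj holds the per-output-column
// adjustment term (n elements, int32) and azp the optional per-row zero
// points (m elements, int32), as checked in cutlass_scaled_mm_azp below.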
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);
#endif
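
// Note: CUDA_VERSION comes from the CUDA toolkit headers (pulled in via
// cudaTypedefs.h); the sm90 declarations above are guarded because those
// kernels only build with CUDA >= 12.0.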

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  // CUTLASS FP8 kernels need at least:
  //   CUDA 12.0 on SM90 systems (Hopper)
  //   CUDA 12.4 on SM89 systems (Ada Lovelace)
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  } else if (cuda_device_capability >= 89) {
    return CUDA_VERSION >= 12040;
  }
#endif
  return false;
}

// Returns the compute capability as major * 10 + minor (e.g. 90 for SM90).
// Note this queries device 0, not necessarily the device the input tensors
// live on.
int32_t get_sm_version_num() {
  int32_t major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}
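
// Example (illustrative only):
//   int32_t sm = get_sm_version_num();             // e.g. 90 on an SM90 GPU
//   bool fp8_ok = cutlass_scaled_mm_supports_fp8(sm);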

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  // Scales are either per-tensor (one element) or per-row/per-column.
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16-byte alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  // Dispatch on the device's compute capability, newest first.
  int32_t version_num = get_sm_version_num();
  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else
    // Fall back to the sm80 kernels when built without CUDA 12.0.
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
  }
}
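
// A minimal usage sketch (hypothetical caller; the sizes and tensor setup
// are illustrative, assuming int8 inputs with per-tensor float32 scales):
//
//   int64_t M = 128, N = 256, K = 512;
//   auto opts = torch::dtype(torch::kInt8).device(torch::kCUDA);
//   auto a = torch::randint(-8, 8, {M, K}, opts);  // row-major [M, K]
//   auto b = torch::randint(-8, 8, {K, N}, opts)
//                .t().contiguous().t();            // column-major [K, N]
//   auto c = torch::empty(
//       {M, N}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
//   auto scale = torch::ones(
//       {1}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
//   cutlass_scaled_mm(c, a, b, scale, scale, c10::nullopt);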

void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16-byte alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are all 1d;
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  // Same dispatch structure as cutlass_scaled_mm above.
  int32_t version_num = get_sm_version_num();
  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
#else
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp,
                               bias);
  }
}
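
// A minimal sketch of how these entry points might be registered as torch
// ops (illustrative only; the real binding lives elsewhere, and "_C" is a
// hypothetical extension name):
//
//   TORCH_LIBRARY_FRAGMENT(_C, ops) {
//     ops.def(
//         "cutlass_scaled_mm(Tensor! c, Tensor a, Tensor b, Tensor a_scales,"
//         " Tensor b_scales, Tensor? bias) -> ()");
//     ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
//   }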