scaled_mm_entry.cu

#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
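
// Architecture-specific scaled-matmul kernels, forward-declared here so this
// translation unit can dispatch to them at runtime; their definitions live in
// the per-architecture source files.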
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales);

void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales);

void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales);
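
// The SM90 declaration (and the matching call site below) is guarded because
// the Hopper kernels require CUDA 12.0 or newer to compile.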
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales);
#endif
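
// Runtime check used by callers to decide whether the CUTLASS FP8 path is
// available for a given compute capability (encoded as major * 10 + minor).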
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  // CUTLASS FP8 kernels need at least
  //   CUDA 12.0 on SM90 systems (Hopper)
  //   CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  } else if (cuda_device_capability >= 89) {
    return CUDA_VERSION >= 12040;
  }
#endif
  return false;
}
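
// Hypothetical caller-side sketch (not part of this file): combine the
// device's compute capability with this check before choosing the FP8 path,
// mirroring the dispatch logic in cutlass_scaled_mm below.
//
//   int32_t major, minor;
//   cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
//   cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
//   bool use_fp8 = cutlass_scaled_mm_supports_fp8(major * 10 + minor);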

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales) {
  int32_t major_capability;
  int32_t minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;
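  // version_num uses the same encoding as cutlass_scaled_mm_supports_fp8:
  // major * 10 + minor (e.g. 90 for Hopper, 89 for Ada Lovelace).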

  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);                  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
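
  // Dispatch to the newest architecture-specific kernel the device supports.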
  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
#else
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
  }
}
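
// Hypothetical usage sketch (not part of this file): the checks above expect a
// row-major `a` and `c`, a column-major `b`, and contiguous scales. For an
// int8 GEMM with per-tensor scales, something along these lines would pass:
//
//   const int64_t M = 32, N = 64, K = 64;  // N and K multiples of 16 for the alignment checks
//   auto opts = torch::TensorOptions().device(torch::kCUDA);
//   auto a = torch::randint(-128, 127, {M, K}, opts.dtype(torch::kInt8));
//   auto b = torch::randint(-128, 127, {N, K}, opts.dtype(torch::kInt8)).t();  // [K, N], column-major
//   auto c = torch::empty({M, N}, opts.dtype(torch::kBFloat16));
//   auto a_scales = torch::ones({1}, opts.dtype(torch::kFloat32));
//   auto b_scales = torch::ones({1}, opts.dtype(torch::kFloat32));
//   cutlass_scaled_mm(c, a, b, a_scales, b_scales);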