scaled_mm_entry.cu

#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);
#endif

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  // CUTLASS FP8 kernels need at least
  //   CUDA 12.0 on SM90 systems (Hopper)
  //   CUDA 12.4 on SM89 systems (Lovelace)

#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  } else if (cuda_device_capability >= 89) {
    // CUTLASS kernels have not been tuned for Ada Lovelace systems
    // and are slower than torch.mm. Return false unconditionally in this case.
    return false;

    // Once the CUTLASS kernels have been optimized for Lovelace systems,
    // use the following check:
    //   return CUDA_VERSION >= 12040;
  }
#endif

  return false;
}
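
// Dispatches to the CUTLASS scaled-GEMM kernel built for the device's compute
// capability (SM75/SM80/SM89/SM90). Expects a and c to be row-major, b to be
// column-major, and the scale tensors to be contiguous and either per-tensor
// (numel == 1) or per-row of a / per-column of b.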
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias) {
  int32_t major_capability;
  int32_t minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;

  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  if (version_num >= 90) {
    // Hopper

    // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
#endif
  } else if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
  } else if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
  } else {
    // Turing
    TORCH_CHECK(version_num >= 75);
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
  }
}
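
// ---------------------------------------------------------------------------
// Illustrative sketch only: one way this entry point could be exposed to
// Python as a custom torch op. The real registration is assumed to live in a
// separate bindings file; the library name "_C" and the schema string below
// are assumptions made for this example, not necessarily the project's actual
// binding.
// ---------------------------------------------------------------------------
#include <torch/library.h>

TORCH_LIBRARY_FRAGMENT(_C, m) {
  // Out-variant op: the result is written into `c` in place, so the schema
  // marks it mutable and the op returns nothing.
  m.def(
      "cutlass_scaled_mm(Tensor! c, Tensor a, Tensor b, Tensor a_scales, "
      "Tensor b_scales, Tensor? bias) -> ()");
  m.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
}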