#include #include #include void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); #if defined CUDA_VERSION && CUDA_VERSION >= 12000 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias); #endif bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { // CUTLASS FP8 kernels need at least // CUDA 12.0 on SM90 systems (Hopper) // CUDA 12.4 on SM89 systems (Lovelace) #if defined CUDA_VERSION if (cuda_device_capability >= 90) { return CUDA_VERSION >= 12000; } else if (cuda_device_capability >= 89) { // CUTLASS Kernels have not been tuned for Ada Lovelace systems // and are slower than torch.mm. Return false unconditionally in this case. return false; // Once the CUTLASS kernels have been optimized for Lovelace systems, // use the following check: // return CUDA_VERSION >= 12040; } #endif return false; } void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, c10::optional const& bias) { int32_t major_capability; int32_t minor_capability; cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0); cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0); int32_t version_num = major_capability * 10 + minor_capability; // Checks for conformality TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && b.size(1) == c.size(1)); TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); // Check for strides and alignment TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major TORCH_CHECK(b.stride(0) == 1); // Column-major TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); if (bias) { TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && bias->dim() == 1); } at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); if (version_num >= 90) { // Hopper // Guard against compilation issues for sm90 kernels #if defined CUDA_VERSION && CUDA_VERSION >= 12000 cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias); #else cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); #endif } else if (version_num == 89) { // Ada Lovelace cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias); } else if (version_num >= 80) { // Ampere cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); } else { // Turing TORCH_CHECK(version_num >= 75); cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); } }