Browse source code

fix: cuda version check for fp8 support in the cutlass kernels

AlpinDale 7 months ago
parent
commit
cd9ed8623b

+ 4 - 0
aphrodite/_custom_ops.py

@@ -209,6 +209,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 
 
 # cutlass
+def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+
 def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype]) -> torch.Tensor:
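
For reference, the new wrapper just forwards to the compiled extension. A minimal usage sketch, assuming the aphrodite extension is built and importable; the major * 10 + minor packing mirrors cutlass_fp8_supported in aphrodite/quantization/fp8.py below:

import torch

from aphrodite import _custom_ops as ops  # assumed import path

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor  # (9, 0) -> 90 for Hopper, (8, 9) -> 89 for Lovelace
print(ops.cutlass_scaled_mm_supports_fp8(capability))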

+ 1 - 0
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py

@@ -1,3 +1,4 @@
+from typing import Callable, List, Tuple, Union
 
 import torch
 from torch.nn import Parameter

+ 1 - 13
aphrodite/quantization/fp8.py

@@ -18,19 +18,7 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
 def cutlass_fp8_supported() -> bool:
     capability = torch.cuda.get_device_capability()
     capability = capability[0] * 10 + capability[1]
-    major, minor = torch.version.cuda.split(".")
-    version = int(major) * 10 + int(minor)
-
-    # CUTLASS FP8 kernels need at least
-    #   CUDA 12.0 on SM90 systems (Hopper)
-    #   CUDA 12.4 on SM89 systems (Lovelace)
-    gpu_is_supported = False
-    if capability >= 90:
-        gpu_is_supported = version > 120
-    elif capability >= 89:
-        gpu_is_supported = version > 124
-
-    return gpu_is_supported
+    return ops.cutlass_scaled_mm_supports_fp8(capability)
 
 
 class Fp8Config(QuantizationConfig):
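
Two details motivate the swap above: torch.version.cuda reports the CUDA version PyTorch itself was built against (and is None on CPU-only builds), which can differ from the toolkit that actually compiled these kernels; and the removed strict > comparisons rejected exactly CUDA 12.0 on SM90 and CUDA 12.4 on SM89, the very minimums the comment names. Below is a corrected pure-Python rendering of the old check, for comparison only; the commit instead delegates to the extension, whose CUDA_VERSION is fixed at its own build time:

import torch

def fp8_supported_reference() -> bool:
    # Reference-only sketch; not the code path this commit adds.
    if torch.version.cuda is None:  # CPU-only build; the old code crashed here
        return False
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")[:2])
    version = cuda_major * 10 + cuda_minor
    if capability >= 90:  # SM90 (Hopper)
        return version >= 120  # ">=", not ">": CUDA 12.0 itself qualifies
    if capability >= 89:  # SM89 (Lovelace)
        return version >= 124
    return False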

+ 24 - 25
kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu

@@ -465,29 +465,28 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                                              ScaledEpilogue>(
           out, a, b, a_scales, b_scales);
-    }
-    } else {
-      TORCH_CHECK(out.dtype() == torch::kFloat16);
-
-      return cutlass_gemm_caller<
-          cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
-                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
-          out, a, b, a_scales, b_scales);
-    }
-  } else {
-    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
-    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
-
-    if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_sm90_fp8_dispatch<
-          cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
-          out, a, b, a_scales, b_scales);
-    } else {
-      TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
-                                            cutlass::half_t, ScaledEpilogue>(
-          out, a, b, a_scales, b_scales);
-    }
-  }
-}
+    } else {
+      TORCH_CHECK(out.dtype() == torch::kFloat16);
+
+      return cutlass_gemm_caller<
+          cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
+                          ClusterShape, KernelSchedule, EpilogueSchedule>>(
+          out, a, b, a_scales, b_scales);
+    }
+  } else {
+    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+
+    if (out.dtype() == torch::kBFloat16) {
+      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
+                                            cutlass::bfloat16_t, ScaledEpilogue>(
+          out, a, b, a_scales, b_scales);
+    } else {
+      TORCH_CHECK(out.dtype() == torch::kFloat16);
+      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
+                                            cutlass::half_t, ScaledEpilogue>(
+          out, a, b, a_scales, b_scales);
+    }
+  }
+}
 
 #endif
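
The re-bracing above restores the intended dispatch: int8 x int8 or fp8-e4m3 x fp8-e4m3 inputs, each with a bf16 or fp16 output. An end-to-end sketch from the Python side, assuming the wrapper from aphrodite/_custom_ops.py; the kernels' layout requirements (vLLM-lineage code expects a row-major and b column-major) are not visible in this hunk, so treat the tensor setup as illustrative only:

import torch
from aphrodite import _custom_ops as ops  # assumed import path

major, minor = torch.cuda.get_device_capability()
if ops.cutlass_scaled_mm_supports_fp8(major * 10 + minor):
    a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn)
    b = b.t().contiguous().t()  # column-major copy (assumption, see above)
    scale_a = torch.ones(1, device="cuda", dtype=torch.float32)
    scale_b = torch.ones(1, device="cuda", dtype=torch.float32)
    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
    print(out.shape, out.dtype)  # torch.Size([16, 64]) torch.bfloat16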

+ 16 - 0
kernels/quantization/cutlass_w8a8/scaled_mm_entry.cu

@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b_scales);
 #endif
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
+  // CUTLASS FP8 kernels need at least
+  //   CUDA 12.0 on SM90 systems (Hopper)
+  //   CUDA 12.4 on SM89 systems (Lovelace)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
+  } else if (cuda_device_capability >= 89) {
+    return CUDA_VERSION >= 12040;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales) {
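
CUDA_VERSION encodes the toolkit as major * 1000 + minor * 10, so 12000 is CUDA 12.0 and 12040 is CUDA 12.4, and it is evaluated when the extension is compiled, not at runtime. A Python mirror of the new predicate with its truth table, for illustration:

def supports_fp8(capability: int, cuda_version: int) -> bool:
    # Mirrors cutlass_scaled_mm_supports_fp8; cuda_version uses the
    # CUDA_VERSION encoding major * 1000 + minor * 10.
    if capability >= 90:  # SM90 (Hopper)
        return cuda_version >= 12000
    if capability >= 89:  # SM89 (Lovelace)
        return cuda_version >= 12040
    return False

assert supports_fp8(90, 12000)      # Hopper, CUDA 12.0: supported
assert not supports_fp8(89, 12000)  # Lovelace, CUDA 12.0: needs 12.4
assert supports_fp8(89, 12040)      # Lovelace, CUDA 12.4: supported
assert not supports_fp8(80, 12040)  # Ampere: no CUTLASS FP8 path at all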

+ 2 - 0
kernels/quantization/quant_ops.h

@@ -85,6 +85,8 @@ at::Tensor e8p_mm_origorder(const at::Tensor& A, const at::Tensor& B,
 void decompress_e8p_origorder(torch::Tensor YIs, torch::Tensor CB,
                               torch::Tensor& Y);
 
+bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales);

+ 6 - 0
kernels/torch_bindings.cpp

@@ -142,6 +142,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "                  Tensor b_scales) -> ()");
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
 
+  // Check if cutlass scaled_mm is supported for CUDA devices of the given
+  // capability
+  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
+           &cutlass_scaled_mm_supports_fp8);
+
   // QuIP# GEMV
   ops.def("quip_gemv", &e8p_mm_origorder);
   ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);