Bladeren bron

chore: use cutlass kernels for fp8 if supported

AlpinDale 7 maanden geleden
bovenliggende
commit
40bc98b363

+ 4 - 2
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py

@@ -18,14 +18,16 @@ __all__ = ["CompressedTensorsW8A8StaticTensor"]
 
 # cutlass
 def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
-                         a_scales: torch.Tensor, b_scales: torch.Tensor,
+                         scale_a: torch.Tensor, scale_b: torch.Tensor,
                          out_dtype: Type[torch.dtype]) -> torch.Tensor:
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
     m = a.shape[0]
     n = b.shape[1]
     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    quant_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
+
+    quant_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
+
     return out
 
 

+ 53 - 18
aphrodite/quantization/fp8.py

@@ -6,11 +6,13 @@ from loguru import logger
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
+from aphrodite.common.utils import print_warning_once
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
-from aphrodite.common.utils import print_warning_once
+from aphrodite.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_statictensor import \
+    cutlass_scaled_mm_dq  # noqa: E501
 
 HAS_QUANTS = False
 with suppress(ImportError):
@@ -56,6 +58,24 @@ def scaled_fp8_quant(
     return output, scale
 
 
+def cutlass_fp8_supported() -> bool:
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    version = torch.version.cuda
+    version = version[0] * 10 + version[1]
+
+    # CUTLASS FP8 kernels need at least
+    #   CUDA 12.0 on SM90 systems (Hopper)
+    #   CUDA 12.4 on SM89 systems (Lovelace)
+    gpu_is_supported = False
+    if capability >= 900:
+        gpu_is_supported = version > 120
+    elif capability >= 890:
+        gpu_is_supported = version > 124
+
+    return gpu_is_supported
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
@@ -99,7 +119,8 @@ class Fp8Config(QuantizationConfig):
 
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
-        from aphrodite.attention.layer import Attention  # Avoid circular import
+        from aphrodite.attention.layer import \
+            Attention  # Avoid circular import
 
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
@@ -131,6 +152,7 @@ class Fp8LinearMethod(LinearMethodBase):
         if not HAS_QUANTS:
             raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
 
     def _create_scale_param(
         self,
@@ -274,22 +296,35 @@ class Fp8LinearMethod(LinearMethodBase):
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         #   If dynamic, layer.act_scale is None and x_scale computed from x.
         #   If static,  layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = scaled_fp8_quant(x,
-                                           layer.act_scale,
-                                           batch_dim_padding=17)
-
-        # Fused GEMM_DQ -- note we padded the input above because
-        # torch._scaled_mm is more performant for matrices with
-        # batch dimension > 16. Note that this could change
-        # in the future.
-        output, _ = torch._scaled_mm(
-            qinput,
-            layer.weight,
-            out_dtype=x.dtype,
-            scale_a=x_scale,
-            scale_b=layer.weight_scale,
-            bias=bias,
-        )
+        if bias is None and self.cutlass_fp8_supported:
+            qinput, x_scale = scaled_fp8_quant(x, layer.act_scale)
+
+            # Fused GEMM_DQ
+            output = cutlass_scaled_mm_dq(
+                qinput,
+                layer.weight,
+                out_dtype=x.dtype,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+            )
+
+        else:
+            qinput, x_scale = scaled_fp8_quant(x,
+                                               layer.act_scale,
+                                               batch_dim_padding=17)
+
+            # Fused GEMM_DQ -- note we padded the input above because
+            # torch._scaled_mm is more performant for matrices with
+            # batch dimension > 16. Note that this could change
+            # in the future.
+            output, _ = torch._scaled_mm(
+                qinput,
+                layer.weight,
+                out_dtype=x.dtype,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+                bias=bias,
+            )
 
         return torch.narrow(output, 0, 0, x.shape[0])