@@ -6,11 +6,13 @@ from loguru import logger
 from torch.nn import Module
 from torch.nn.parameter import Parameter

+from aphrodite.common.utils import print_warning_once
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
-from aphrodite.common.utils import print_warning_once
+from aphrodite.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_statictensor import \
+    cutlass_scaled_mm_dq  # noqa: E501

 HAS_QUANTS = False
 with suppress(ImportError):
@@ -56,6 +58,24 @@ def scaled_fp8_quant(
     return output, scale


+def cutlass_fp8_supported() -> bool:
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    major, minor = torch.version.cuda.split(".")
+    version = int(major) * 10 + int(minor)
+
+    # CUTLASS FP8 kernels need at least
+    #   CUDA 12.0 on SM90 systems (Hopper)
+    #   CUDA 12.4 on SM89 systems (Lovelace)
+    gpu_is_supported = False
+    if capability >= 90:
+        gpu_is_supported = version >= 120
+    elif capability >= 89:
+        gpu_is_supported = version >= 124
+
+    return gpu_is_supported
+
+
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""

@@ -99,7 +119,8 @@ class Fp8Config(QuantizationConfig):

     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
-        from aphrodite.attention.layer import Attention  # Avoid circular import
+        from aphrodite.attention.layer import \
+            Attention  # Avoid circular import

         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
@@ -131,6 +152,7 @@ class Fp8LinearMethod(LinearMethodBase):
         if not HAS_QUANTS:
             raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()

     def _create_scale_param(
         self,
@@ -274,22 +296,35 @@ class Fp8LinearMethod(LinearMethodBase):
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         # If dynamic, layer.act_scale is None and x_scale computed from x.
         # If static, layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = scaled_fp8_quant(x,
-                                           layer.act_scale,
-                                           batch_dim_padding=17)
-
-        # Fused GEMM_DQ -- note we padded the input above because
-        # torch._scaled_mm is more performant for matrices with
-        # batch dimension > 16. Note that this could change
-        # in the future.
-        output, _ = torch._scaled_mm(
-            qinput,
-            layer.weight,
-            out_dtype=x.dtype,
-            scale_a=x_scale,
-            scale_b=layer.weight_scale,
-            bias=bias,
-        )
+        if bias is None and self.cutlass_fp8_supported:
+            qinput, x_scale = scaled_fp8_quant(x, layer.act_scale)
+
+            # Fused GEMM_DQ
+            output = cutlass_scaled_mm_dq(
+                qinput,
+                layer.weight,
+                out_dtype=x.dtype,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+            )
+
+        else:
+            qinput, x_scale = scaled_fp8_quant(x,
+                                               layer.act_scale,
+                                               batch_dim_padding=17)
+
+            # Fused GEMM_DQ -- note we padded the input above because
+            # torch._scaled_mm is more performant for matrices with
+            # batch dimension > 16. Note that this could change
+            # in the future.
+            output, _ = torch._scaled_mm(
+                qinput,
+                layer.weight,
+                out_dtype=x.dtype,
+                scale_a=x_scale,
+                scale_b=layer.weight_scale,
+                bias=bias,
+            )

         return torch.narrow(output, 0, 0, x.shape[0])

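As a sanity check on the new gating, here is a minimal standalone sketch (not part of the patch; the helper name cutlass_fp8_ok and the device/CUDA pairs are illustrative assumptions) that mirrors the arithmetic performed by cutlass_fp8_supported() so the thresholds can be reasoned about without a GPU:

# Illustrative sketch only: same capability/version encoding as the patch.
def cutlass_fp8_ok(cc_major: int, cc_minor: int,
                   cuda_major: int, cuda_minor: int) -> bool:
    capability = cc_major * 10 + cc_minor   # e.g. SM90 -> 90, SM89 -> 89
    version = cuda_major * 10 + cuda_minor  # e.g. CUDA 12.4 -> 124
    if capability >= 90:                    # Hopper: needs CUDA 12.0+
        return version >= 120
    if capability >= 89:                    # Ada Lovelace: needs CUDA 12.4+
        return version >= 124
    return False                            # older GPUs use torch._scaled_mm

assert cutlass_fp8_ok(9, 0, 12, 2)          # SM90 + CUDA 12.2 -> CUTLASS path
assert not cutlass_fp8_ok(8, 9, 12, 1)      # SM89 + CUDA 12.1 -> fallback path
assert not cutlass_fp8_ok(8, 0, 12, 4)      # SM80 (no FP8 hardware) -> fallback

Note that even when the CUTLASS path is available, apply() still falls back to torch._scaled_mm whenever a bias is supplied, since the cutlass_scaled_mm_dq call in this patch takes no bias argument.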