from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(activation_scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    We currently support common FP16/BF16 model checkpoints ONLY. The weight
    scaling factor is initialized after the model weights are loaded.

    Limitations:
    1. Only per-tensor quantization is supported due to torch._scaled_mm
       support.
    2. Only the float8_e4m3fn data type is supported due to the limitation of
       torch._scaled_mm
       (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, extra_weight_attrs)

        w_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("weight_scaling_factor", w_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        # Although the linear_method is propagated to all layers,
        # only linear layers invoke "create_weights". So we check
        # whether "weight_scaling_factor" is registered to determine
        # whether the layer is a linear layer that requires quantization.
        if not hasattr(layer, "weight_scaling_factor"):
            return

        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
        # torch._scaled_mm requires column-major in the second
        # input (weight), so we transpose the quantized weight.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scaling_factor.data.copy_(weight_scale)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qinput, x_scale = ops.scaled_fp8_quant(x)
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scaling_factor,
            bias=bias,
        )
        return output
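

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the layer above): a pure-PyTorch
# reference for per-tensor dynamic FP8 quantization, i.e. the behavior this
# module assumes from `ops.scaled_fp8_quant` for the "dynamic" activation
# scheme. The helper name `_ref_scaled_fp8_quant` is hypothetical and exists
# only for this sketch; the real kernel lives in aphrodite._quant_C.
# ---------------------------------------------------------------------------
def _ref_scaled_fp8_quant(x: torch.Tensor):
    """Quantize `x` to float8_e4m3fn with a single per-tensor scale.

    Returns (qx, scale) such that qx.float() * scale approximates x.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Per-tensor scale: map the largest magnitude in `x` onto the FP8 max.
    scale = x.abs().max().to(torch.float32).clamp(min=1e-12) / finfo.max
    qx = (x.to(torch.float32) / scale).clamp(min=finfo.min, max=finfo.max)
    return qx.to(torch.float8_e4m3fn), scale


if __name__ == "__main__":
    # Round-trip check of the reference quantizer on random FP16 data.
    w = torch.randn(128, 256, dtype=torch.half)
    qw, w_scale = _ref_scaled_fp8_quant(w)
    w_hat = qw.to(torch.float32) * w_scale
    err = (w.to(torch.float32) - w_hat).abs().max()
    print(f"max abs dequantization error: {err.item():.4f}")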