from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)
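
# The FP8 quantization kernels live in the optional aphrodite._quant_C
# extension. Track their availability here so that Fp8LinearMethod can raise
# a clear error later instead of failing at import time.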
HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
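        # 89 == CUDA compute capability 8.9 (Ada Lovelace). FP8 tensor cores
        # require Ada or newer GPUs (Hopper is 9.0).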
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(activation_scheme)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
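        # FP8 is only applied to linear layers; any other layer falls back to
        # its default (unquantized) method.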
        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Only common FP16/BF16 model checkpoints are supported for now. The
    weight scaling factor is initialized after the model weights are loaded.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm.
    2. Only the float8_e4m3fn data type is supported, due to the limitation
       of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856).

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
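
        # Create the weight in the checkpoint dtype (FP16/BF16) so that
        # regular weights can be loaded into it; it is quantized to FP8 in
        # process_weights_after_loading.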
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, extra_weight_attrs)
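
        # A single per-tensor weight scale in float32, filled in once the
        # weight has been quantized.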
        w_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("weight_scaling_factor", w_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        # Although the linear_method is propagated to all layers, only linear
        # layers invoke "create_weights". So we check whether
        # "weight_scaling_factor" is registered to determine whether the
        # layer is a linear layer that requires quantization.
        if not hasattr(layer, "weight_scaling_factor"):
            return
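
        # Quantize the loaded FP16/BF16 weight to FP8 and record its
        # per-tensor scale.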
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
        # torch._scaled_mm requires column-major in the second
        # input (weight), so we transpose the quantized weight.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scaling_factor.data.copy_(weight_scale)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
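        # Dynamically quantize the activation to FP8 with a fresh per-tensor
        # scale on every call.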
        qinput, x_scale = ops.scaled_fp8_quant(x)
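
        # torch._scaled_mm multiplies the FP8 operands and rescales the
        # result by scale_a * scale_b, returning it in out_dtype (here, the
        # original activation dtype).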
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scaling_factor,
            bias=bias,
        )
        return output
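

# A minimal usage sketch for illustration only (the call sequence below is an
# assumption about how the engine drives this class, not part of its API):
#
#   config = Fp8Config.from_config({"activation_scheme": "dynamic"})
#   quant_method = config.get_quant_method(layer)      # Fp8LinearMethod for
#                                                       # LinearBase layers
#   quant_method.create_weights(layer, ...)             # before loading
#   quant_method.process_weights_after_loading(layer)   # quantize to FP8
#   y = quant_method.apply(layer, x)                     # FP8 forward pass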