123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- from typing import Any, Dict, List, Optional
- import torch
- from torch.nn import Module
- from torch.nn.parameter import Parameter
- from aphrodite.modeling.layers.linear import (LinearMethodBase,
- set_weight_attrs)
- from aphrodite.quantization.base_config import QuantizationConfig
class FP8Config(QuantizationConfig):
    """Quantization config for FP8 (float8_e4m3fn) inference.

    This config carries no tunable state: ``from_config`` ignores its
    argument and no config files are looked up.
    """

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        supported_dtypes = [torch.bfloat16, torch.half]
        return supported_dtypes

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum required GPU capability.

        NOTE(review): presumably encoded as ``major * 10 + minor``
        (i.e. 8.9); confirm against the caller.
        """
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """No quantization config file is needed for this method."""
        return list()

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
        """Build the config; the contents of ``config`` are not used."""
        return cls()

    def get_linear_method(self) -> "Fp8LinearMethod":
        """Return the linear method that applies FP8 quantization."""
        return Fp8LinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """No activation names are reported as scaled."""
        return list()

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None

    def quant_vocab(self) -> List[bool]:
        return [False] * 2

    def support_fused_moe(self) -> bool:
        # NOTE(review): Fp8LinearMethod.apply_moe_weights raises
        # NotImplementedError even though this returns True; confirm how
        # callers interpret this flag.
        return True
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Only common FP16/BF16 model checkpoints are supported. The weight is
    loaded unquantized and converted to FP8, together with its scaling
    factor, after the model weights are loaded.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm.
    2. Only the float8_e4m3fn data type is supported, due to the limitation
       of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: FP8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized weight and a weight-scale slot on ``layer``.

        The weight is allocated in ``params_dtype`` with shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` so a
        FP16/BF16 checkpoint can be loaded into it directly. It is quantized
        later in ``process_weights_after_loading``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, extra_weight_attrs)
        # Uninitialized placeholder; overwritten with the real inverse scale
        # in process_weights_after_loading.
        w_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("weight_scaling_factor", w_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize ``layer.weight`` to FP8 in place and store its scale."""
        # Although the linear_method is propagated to all layers,
        # only linear layers invoke "create_weights". So we check
        # whether "weight_scaling_factor" is registered to determine
        # whether the layer is a linear layer that requires quantization.
        if not hasattr(layer, "weight_scaling_factor"):
            return

        qweight, weight_scale = per_tensor_quantize(layer.weight)
        # torch._scaled_mm requires column-major in the second
        # input (weight), so we transpose the quantized weight.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scaling_factor.data.copy_(weight_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the linear layer with FP8 operands via torch._scaled_mm.

        The input is dynamically quantized with its own per-tensor scale on
        every call.

        Args:
            layer: A layer already processed by
                ``process_weights_after_loading`` (transposed FP8 weight
                plus ``weight_scaling_factor``).
            x: Input activations. Any leading dimensions are supported; they
                are flattened because torch._scaled_mm only accepts 2-D
                operands.
            bias: Optional bias forwarded to torch._scaled_mm.

        Returns:
            A tensor of dtype ``x.dtype`` and shape
            ``(*x.shape[:-1], out_features)``.
        """
        qinput, x_scale = per_tensor_quantize(x)
        # torch._scaled_mm only accepts 2-D matrices: flatten leading
        # (batch/sequence) dims here and restore them on the output.
        qinput_2d = qinput.reshape(-1, qinput.shape[-1])
        result = torch._scaled_mm(
            qinput_2d,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scaling_factor,
            bias=bias,
        )
        # Older PyTorch releases (which this code was written against)
        # return (output, amax); newer releases return just the output.
        output = result[0] if isinstance(result, tuple) else result
        return output.reshape(*x.shape[:-1], output.shape[-1])

    def apply_moe_weights(self, w1: Dict[str,
                                         torch.Tensor], w2: Dict[str,
                                                                 torch.Tensor],
                          x: torch.Tensor, gating_output: torch.Tensor,
                          topk: int, renormalize: bool) -> torch.Tensor:
        """Fused MoE is not implemented for FP8.

        NOTE(review): FP8Config.support_fused_moe() returns True although
        this always raises; confirm how callers use that flag.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Fused MoE is not supported by the FP8 linear method.")
- def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]:
- """Quantize a tensor using per-tensor static scaling factor.
- Args:
- tensor: The input tensor.
- """
- finfo = torch.finfo(torch.float8_e4m3fn)
- # Calculate the scale as dtype max divided by absmax.
- # Since .abs() creates a new tensor, we use aminmax to get
- # the min and max first and then calculate the absmax.
- min_val, max_val = tensor.aminmax()
- amax = min_val.abs().max(max_val.abs())
- scale = finfo.max / amax.clamp(min=1e-12)
- # scale and clamp the tensor to bring it to
- # the representative range of float8 data type
- # (as default cast is unsaturated)
- qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
- # Return both float8 data and the inverse scale (as float),
- # as both required as inputs to torch._scaled_mm
- qweight = qweight.to(torch.float8_e4m3fn)
- scale = scale.float().reciprocal()
- return qweight, scale
|