123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166 |
- from typing import Any, Dict, List, Optional
- import torch
- from torch.nn import Module
- from torch.nn.parameter import Parameter
- from aphrodite.common.utils import is_hip
- from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
- UnquantizedLinearMethod)
- from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
- ModelWeightParameter)
- from aphrodite.platforms import current_platform
- from aphrodite.quantization.base_config import (QuantizationConfig,
- QuantizeMethodBase)
- from aphrodite.quantization.fp8 import cutlass_fp8_supported
- from aphrodite.quantization.utils.marlin_utils_fp8 import (
- apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
- from aphrodite.quantization.utils.quant_utils import is_layer_skipped
- from aphrodite.quantization.utils.w8a8_utils import (
- apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
class FBGEMMFp8Config(QuantizationConfig):
    """Quantization settings for the FBGEMM FP8 scheme.

    Attributes set by ``__init__``:
        ignore_list: Module prefixes that must stay unquantized (an empty
            list when a falsy value is passed).
        input_scale_ub: Upper bound forwarded to the linear method as the
            activation scale limit.
        use_marlin: ``True`` when the current device capability is below
            89, in which case the Marlin weight-only FP8 kernel is used.
    """

    def __init__(self, ignore_list: List[str], input_scale_ub: float):
        self.ignore_list = ignore_list or []
        self.input_scale_ub = input_scale_ub

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        device_cap = current_platform.get_device_capability()
        # Fold (major, minor) into a single integer, e.g. (8, 9) -> 89.
        flat_cap = device_cap[0] * 10 + device_cap[1]
        self.use_marlin = flat_cap < 89

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (major * 10 + minor)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config file names to look up; none are needed."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        """Build a config from a parsed quantization-config dictionary.

        Reads ``modules_to_not_convert`` (the ignore list) and
        ``activation_scale_ub`` (the input scale upper bound) from
        ``config`` through ``get_from_keys``.
        """
        return cls(
            ignore_list=cls.get_from_keys(config, ["modules_to_not_convert"]),
            input_scale_ub=cls.get_from_keys(config,
                                             ["activation_scale_ub"]),
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer``.

        Returns ``None`` for anything that is not a ``LinearBase``, an
        ``UnquantizedLinearMethod`` when ``is_layer_skipped`` reports that
        ``prefix`` is covered by the ignore list, and an
        ``FBGEMMFp8LinearMethod`` otherwise.
        """
        if not isinstance(layer, LinearBase):
            return None
        if is_layer_skipped(prefix, self.ignore_list):
            return UnquantizedLinearMethod()
        return FBGEMMFp8LinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return names of scaled activation functions; there are none."""
        return []
class FBGEMMFp8LinearMethod(LinearMethodBase):
    """Linear method that creates, post-processes and applies FBGEMM FP8
    weights for a linear layer.

    Depending on ``quant_config.use_marlin`` the forward pass goes through
    ``apply_fp8_marlin_linear`` or ``apply_fp8_linear``.
    """

    def __init__(self, quant_config: FBGEMMFp8Config):
        self.quant_config = quant_config
        # Queried once here and passed to apply_fp8_linear on every call.
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight, its scale and the input scale bound.

        Args:
            layer: Module that receives the parameters and bookkeeping
                attributes.
            input_size_per_partition: Second dimension of the weight.
            output_partition_sizes: Per-shard output widths; their sum is
                the first dimension of the weight and of the scale.
            input_size: Unused.
            output_size: Unused.
            params_dtype: Stored on the layer as ``orig_dtype``.
            **extra_weight_attrs: Only ``"weight_loader"`` is read.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        # The unpartitioned sizes are not needed here.
        del input_size, output_size

        output_size_per_partition = sum(output_partition_sizes)

        # Bookkeeping attributes read later (e.g. by the Marlin path).
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_storage = torch.empty(output_size_per_partition,
                                     input_size_per_partition,
                                     dtype=torch.float8_e4m3fn)
        weight = ModelWeightParameter(data=weight_storage,
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE (one value per output row)
        scale_storage = torch.empty((output_size_per_partition, 1),
                                    dtype=torch.float32)
        weight_scale = ChannelQuantScaleParameter(
            data=scale_storage,
            output_dim=0,
            weight_loader=weight_loader)
        # Filled with the smallest float32 value; presumably a sentinel for
        # "not loaded yet" -- TODO confirm.
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE UPPER BOUND
        upper_bound = torch.tensor(self.quant_config.input_scale_ub,
                                   dtype=torch.float32)
        layer.input_scale_ub = torch.nn.Parameter(upper_bound,
                                                  requires_grad=False)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize the loaded parameters so ``apply`` can use them.

        Re-wraps weight and scale as plain ``Parameter`` objects, passes
        them through ``normalize_e4m3fn_to_e4m3fnuz`` on HIP, stores the
        weight transposed and, for the Marlin path, calls
        ``prepare_fp8_layer_for_marlin`` and drops ``input_scale_ub``.
        """
        # required by torch.compile
        layer.weight_scale = Parameter(layer.weight_scale.data,
                                       requires_grad=False)
        layer.weight = Parameter(layer.weight.data, requires_grad=False)

        fp8_weight = layer.weight

        if is_hip():
            fp8_weight, hip_weight_scale, hip_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(weight=fp8_weight,
                                             weight_scale=layer.weight_scale,
                                             input_scale=None))
            if hip_input_scale is not None:
                layer.input_scale = Parameter(hip_input_scale,
                                              requires_grad=False)
            layer.weight_scale = Parameter(hip_weight_scale,
                                           requires_grad=False)

        # The weight is kept transposed from here on.
        layer.weight = Parameter(fp8_weight.t(), requires_grad=False)

        if self.quant_config.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale_ub

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized linear layer on ``x``.

        Args:
            layer: Module prepared by ``process_weights_after_loading``.
            x: Input activations.
            bias: Optional bias forwarded to the kernel wrapper.

        Returns:
            The result of ``apply_fp8_marlin_linear`` when Marlin is in
            use, otherwise the result of ``apply_fp8_linear``.
        """
        if not self.quant_config.use_marlin:
            # No static input scale is supplied; only the upper bound.
            return apply_fp8_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=None,
                input_scale_ub=layer.input_scale_ub,
                bias=bias,
                cutlass_fp8_supported=self.cutlass_fp8_supported,
                use_per_token_if_dynamic=True)

        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)
|