from typing import Any, Callable, Dict, List, Optional

import torch
from loguru import logger
from torch.nn import Module
from torch.nn.parameter import Parameter

import aphrodite.common.envs as envs
from aphrodite import _custom_ops as ops
from aphrodite.common.utils import is_hip, print_warning_once
from aphrodite.modeling.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                 FusedMoeWeightScaleSupported)
from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                              UnquantizedLinearMethod)
from aphrodite.modeling.parameter import (ModelWeightParameter,
                                          PerTensorScaleParameter)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)
from aphrodite.quantization.kv_cache import BaseKVCacheMethod
from aphrodite.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from aphrodite.quantization.utils.quant_utils import is_layer_skipped
from aphrodite.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, convert_to_channelwise,
    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
    requantize_with_max_scale)

ACTIVATION_SCHEMES = ["static", "dynamic"]

APHRODITE_TEST_FORCE_FP8_MARLIN = envs.APHRODITE_TEST_FORCE_FP8_MARLIN
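
# Note (inferred from how `activation_scheme` is used below, not new
# behavior): "static" reads activation (input) scales from the checkpoint,
# while "dynamic" derives them at runtime from the current activations.
# Static scales only make sense for fp8-serialized checkpoints, which
# `Fp8MoEMethod.create_weights` enforces explicitly.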


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)
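
    # Illustrative example (not executed; field values are hypothetical): a
    # HuggingFace-style `quantization_config` that `from_config` would accept.
    # Only the keys read above ("quant_method", "activation_scheme",
    # "ignored_layers") matter.
    #
    #     Fp8Config.from_config({
    #         "quant_method": "fp8",
    #         "activation_scheme": "static",
    #         "ignored_layers": ["lm_head"],
    #     })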

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from aphrodite.attention.layer import (
            Attention)  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only per-tensor quantization is supported, due to torch._scaled_mm
       support.
    2. Only the float8_e4m3fn data type is supported, due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89 or APHRODITE_TEST_FORCE_FP8_MARLIN
        # Disable marlin for rocm.
        if is_hip():
            self.use_marlin = False
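
    # Worked example of the capability check above (the device numbers are
    # public CUDA compute capabilities, not specific to this codebase):
    #   A100 -> (8, 0) -> 80 < 89 -> Marlin w8a16 weight-only fallback
    #   RTX 4090 / L40 (Ada) -> (8, 9) -> 89 -> native FP8 (w8a8) path
    #   H100 -> (9, 0) -> 90 -> native FP8 (w8a8) path
    # unless APHRODITE_TEST_FORCE_FP8_MARLIN forces Marlin for testing.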

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # If the checkpoint is serialized fp8, register the scales so they
        # can be loaded. Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                            weight_loader=weight_loader)
            scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)
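
    # Shapes created above, for reference (read directly from the code,
    # per tensor-parallel rank):
    #   weight:       [sum(output_partition_sizes), input_size_per_partition]
    #   weight_scale: one fp32 scale per logical shard (per-tensor scheme)
    #   input_scale:  one fp32 scale per logical shard, "static" scheme only
    # The torch.finfo(torch.float32).min fill acts as a placeholder for
    # scales that have not been loaded yet.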

    def process_weights_after_loading(self, layer: Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        # If the checkpoint is not serialized fp8, quantize the weights here.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # If using marlin (w8a16), the kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                assert weight_scale.numel() == 1
                weight_scale = convert_to_channelwise(
                    weight_scale.expand(len(layer.logical_widths)),
                    layer.logical_widths)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If the checkpoint is fp8, handle the fact that a fused module
        # carries N scales for its N logical shards.
        else:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
                                                       requires_grad=False)

            # If using marlin (w8a16), the kernel uses channelwise weights,
            # so extend the weight scales to be channelwise.
            if self.use_marlin:
                weight = layer.weight
                weight_scale = convert_to_channelwise(layer.weight_scale,
                                                      layer.logical_widths)

            # If using w8a8, torch._scaled_mm needs a per-tensor scale, so
            # requantize the logical shards as a single weight.
            else:
                # Dequant -> quant with the max scale so we can run per tensor.
                weight = layer.weight
                weight_scale = layer.weight_scale

                # If rocm, use float8_e4m3fnuz.
                if is_hip():
                    weight, weight_scale, input_scale = \
                        normalize_e4m3fn_to_e4m3fnuz(
                            weight=weight,
                            weight_scale=weight_scale,
                            input_scale=layer.input_scale)
                    if input_scale is not None:
                        layer.input_scale = Parameter(input_scale,
                                                      requires_grad=False)

                weight_scale, weight = requantize_with_max_scale(
                    weight=weight,
                    weight_scale=weight_scale,
                    logical_widths=layer.logical_widths,
                )

            # Update the layer with the new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations are not quantized for marlin.
            del layer.input_scale
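
    # Sketch of the max-scale requantization above (assuming the documented
    # behavior of `requantize_with_max_scale`): for logical shards i with
    # fp8 weights q_i and per-shard scales s_i,
    #   w_i   = q_i * s_i                    # dequantize each shard
    #   s_max = max_i(s_i)                   # one scale for the fused weight
    #   q_i'  = quantize_fp8(w_i / s_max)    # requantize with the shared scale
    # so torch._scaled_mm can run with a single per-tensor weight scale.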

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)
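
    # Conceptually, the w8a8 path above computes (under per-tensor scaling,
    # as configured here; this is a sketch, not the kernel's exact dataflow):
    #   x_q = quantize_fp8(x / x_scale)
    #   y  ~= (x_q @ weight) * x_scale * weight_scale + bias
    # where x_scale is `input_scale` for the static scheme, or is derived
    # from the current batch when the activation scheme is dynamic.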


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None
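
    # Layout of the parameters created above (per rank), for reference:
    #   w13_weight: [num_experts, 2 * intermediate_size, hidden_size]
    #       w1 and w3 (conventionally the gate and up projections) stacked
    #       along dim 1
    #   w2_weight:  [num_experts, hidden_size, intermediate_size] (down proj)
    #   w13_weight_scale: [num_experts, 2], one scale each for w1 and w3,
    #       collapsed to a single per-expert scale after loading (see
    #       process_weights_after_loading below).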

    def process_weights_after_loading(self, layer: Module) -> None:

        # If the checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as the dtype.
            fp8_dtype = torch.float8_e4m3fnuz \
                if is_hip() else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

            # Re-initialize w13_weight_scale because we directly quantize
            # the merged w13 weights and generate a single scaling factor.
            layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                        requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_weight_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If the checkpoint is fp8, the MoE kernels need a single activation
        # scale and a single w13 weight scale per expert, so consolidate the
        # loaded scales here.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if (layer.w13_input_scale is None
                        or layer.w2_input_scale is None):
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.w13_input_scale)
                        or not all_close_1d(layer.w2_input_scale)):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer.")
                layer.w13_input_scale = torch.nn.Parameter(
                    layer.w13_input_scale.max(), requires_grad=False)
                layer.w2_input_scale = torch.nn.Parameter(
                    layer.w2_input_scale.max(), requires_grad=False)

            # If rocm, normalize the weights and scales to e4m3fnuz.
            if is_hip():
                # Normalize the weights and scales.
                w13_weight, w13_weight_scale, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale,
                        layer.w2_input_scale)
                # Reset the parameters.
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)

            # Fp8 moe kernel needs a single weight scale for w13 per expert.
            # We take the max, then dequantize and requantize each shard.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)
            return
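
    # Tiny numeric example of the max-scale merge above (values invented for
    # illustration): if an expert's w13 scales are [0.010, 0.025], the merged
    # per-expert scale becomes 0.025; the w1 shard is dequantized with 0.010
    # and requantized with 0.025, trading a little precision on that shard
    # for the single per-expert scale the fused kernel expects.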

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:

        from aphrodite.modeling.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_fp8_w8a8=True,
                             w1_scale=layer.w13_weight_scale,
                             w2_scale=layer.w2_weight_scale,
                             a1_scale=layer.w13_input_scale,
                             a2_scale=layer.w2_input_scale)


class Fp8KVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)
|