123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461 |
- from typing import Any, Dict, List, Optional
- import torch
- from loguru import logger
- from torch.nn import Module
- from torch.nn.parameter import Parameter
- from aphrodite import _custom_ops as ops
- from aphrodite.common.utils import print_warning_once
- from aphrodite.modeling.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
- fused_moe)
- from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
- from aphrodite.modeling.utils import set_weight_attrs
- from aphrodite.platforms import current_platform
- from aphrodite.quantization.base_config import (QuantizationConfig,
- QuantizeMethodBase)
- from aphrodite.quantization.utils.marlin_utils_fp8 import (
- apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
- from aphrodite.quantization.utils.w8a8_utils import (
- all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
- cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
# Activation quantization schemes that Fp8Config accepts.
ACTIVATION_SCHEMES = ["static", "dynamic"]


class Fp8Config(QuantizationConfig):
    """Quantization config describing how a model uses FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        """Record the checkpoint format flag and the activation scheme.

        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint is already
                stored as fp8. A warning is logged when this is true.
            activation_scheme: Must be one of ``ACTIVATION_SCHEMES``.

        Raises:
            ValueError: If ``activation_scheme`` is not a known scheme.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")

        scheme_is_known = activation_scheme in ACTIVATION_SCHEMES
        if not scheme_is_known:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config supports."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an ``Fp8Config`` from a quantization config mapping.

        The checkpoint counts as fp8-serialized when the value stored under
        ``"quant_method"`` contains ``"fp8"``.
        """
        method_name = cls.get_from_keys(config, ["quant_method"])
        fp8_serialized = "fp8" in method_name
        scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=fp8_serialized,
                   activation_scheme=scheme)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the FP8 quantize method matching the type of ``layer``.

        Returns ``None`` when ``layer`` is neither a linear layer, a fused
        MoE layer, nor an attention layer.
        """
        # Imported lazily to avoid a circular import.
        from aphrodite.attention.layer import Attention

        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        if isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        if isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []
class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # GPUs that lack FP8 hardware support fall back to the Marlin
        # kernel, which offers fast weight-only FP8 quantization.
        device_capability = current_platform.get_device_capability()
        capability_code = device_capability[0] * 10 + device_capability[1]
        self.use_marlin = capability_code < 89

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for fp8 checkpoints, scales) on layer.

        The weight is allocated as float8_e4m3fn when the checkpoint is
        fp8-serialized and as ``params_dtype`` otherwise. Scale parameters
        are only created for fp8-serialized checkpoints; the input scale is
        only created when the activation scheme is "static".
        """
        del input_size, output_size  # Not needed here.

        fp8_checkpoint = self.quant_config.is_checkpoint_fp8_serialized
        total_output_size = sum(output_partition_sizes)

        # Bookkeeping consumed later by weight post-processing and apply().
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size
        layer.orig_dtype = params_dtype

        # WEIGHT
        if fp8_checkpoint:
            weight_dtype = torch.float8_e4m3fn
        else:
            weight_dtype = params_dtype
        weight_storage = torch.empty(total_output_size,
                                     input_size_per_partition,
                                     dtype=weight_dtype)
        weight = Parameter(weight_storage, requires_grad=False)
        layer.register_parameter("weight", weight)
        weight_attrs = dict(extra_weight_attrs, input_dim=1, output_dim=0)
        set_weight_attrs(weight, weight_attrs)

        # Scales are only loaded from fp8-serialized checkpoints; otherwise
        # they are produced in process_weights_after_loading.
        if not fp8_checkpoint:
            return

        # WEIGHT SCALE
        weight_scale = create_per_tensor_scale_param(output_partition_sizes,
                                                     **extra_weight_attrs)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT ACTIVATION SCALE
        if self.quant_config.activation_scheme == "static":
            input_scale = create_per_tensor_scale_param(
                output_partition_sizes, **extra_weight_attrs)
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize the weight, weight scale and input scale of ``layer``.

        Non-fp8 checkpoints are quantized here. Fp8 checkpoints have their
        separately quantized logical weights requantized to one weight with
        a single (max) scale. When Marlin is used, the layer is repacked
        for Marlin and ``input_scale`` is removed.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            # Dequant -> Quant with max scale.
            merged_scale, merged_weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )

            layer.weight = Parameter(merged_weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(merged_scale,
                                           requires_grad=False)

            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)
            else:
                layer.input_scale = None
        else:
            # Checkpoint is not fp8: quantize the loaded weight now.
            quantized, scale = ops.scaled_fp8_quant(layer.weight, scale=None)

            layer.weight = Parameter(quantized.t(), requires_grad=False)
            layer.weight_scale = Parameter(scale, requires_grad=False)
            layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the FP8 linear op, via Marlin when ``use_marlin`` is set."""
        if not self.use_marlin:
            return apply_fp8_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_fp8_supported=self.cutlass_fp8_supported,
                use_per_token_if_dynamic=False)

        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)
class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.

    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights and scales on ``layer``.

        Weights are allocated as float8_e4m3fn for fp8-serialized
        checkpoints and as ``params_dtype`` otherwise.

        Raises:
            ValueError: If the activation scheme is "static" but the
                checkpoint is not fp8-serialized.
        """
        fp8_checkpoint = self.quant_config.is_checkpoint_fp8_serialized
        weight_dtype = torch.float8_e4m3fn if fp8_checkpoint else params_dtype

        def empty_weight(*shape: int) -> torch.nn.Parameter:
            # Uninitialized, frozen weight tensor in the chosen dtype.
            return torch.nn.Parameter(torch.empty(*shape, dtype=weight_dtype),
                                      requires_grad=False)

        def ones_scale(*shape: int) -> torch.nn.Parameter:
            # Frozen float32 scale initialized to 1.
            return torch.nn.Parameter(torch.ones(*shape,
                                                 dtype=torch.float32),
                                      requires_grad=False)

        # WEIGHTS
        w13_weight = empty_weight(num_experts, 2 * intermediate_size,
                                  hidden_size)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = empty_weight(num_experts, hidden_size, intermediate_size)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_scale = ones_scale(num_experts, 2)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = ones_scale(num_experts)
        layer.register_parameter("w2_scale", w2_scale)

        # Only an fp8 checkpoint carries scales to load, so only then are
        # the weight loaders attached. For fp16 checkpoints the scales are
        # computed in process_weights_after_loading().
        if fp8_checkpoint:
            set_weight_attrs(w13_scale, extra_weight_attrs)
            set_weight_attrs(w2_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme != "static":
            layer.a13_scale = None
            layer.a2_scale = None
            return

        if not fp8_checkpoint:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8.")

        a13_scale = ones_scale(num_experts)
        layer.register_parameter("a13_scale", a13_scale)
        set_weight_attrs(a13_scale, extra_weight_attrs)

        a2_scale = ones_scale(num_experts)
        layer.register_parameter("a2_scale", a2_scale)
        set_weight_attrs(a2_scale, extra_weight_attrs)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Bring weights and scales into the form the fp8 MoE kernel needs.

        Non-fp8 checkpoints are quantized per expert. Fp8 checkpoints get a
        single activation scale per layer (static scheme only) and a single
        w13 weight scale per expert.

        Raises:
            ValueError: If the scheme is "static" but an activation scale
                is ``None``.
        """
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # Checkpoint is fp16: quantize every expert here.
            fp8_w13 = torch.empty_like(layer.w13_weight.data,
                                       dtype=torch.float8_e4m3fn)
            fp8_w2 = torch.empty_like(layer.w2_weight.data,
                                      dtype=torch.float8_e4m3fn)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            fresh_w13_scale = torch.ones(layer.num_experts,
                                         dtype=torch.float32,
                                         device=fp8_w13.device)
            layer.w13_scale = torch.nn.Parameter(fresh_w13_scale,
                                                 requires_grad=False)

            for expert in range(layer.num_experts):
                quantized, scale = ops.scaled_fp8_quant(
                    layer.w13_weight.data[expert, :, :])
                fp8_w13[expert, :, :] = quantized
                layer.w13_scale[expert] = scale

                quantized, scale = ops.scaled_fp8_quant(
                    layer.w2_weight.data[expert, :, :])
                fp8_w2[expert, :, :] = quantized
                layer.w2_scale[expert] = scale

            layer.w13_weight = torch.nn.Parameter(fp8_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(fp8_w2, requires_grad=False)
            return

        # Checkpoint is fp8: the MoE kernels require a single activation
        # scale and a single w13 weight scale per expert.
        if self.quant_config.activation_scheme == "static":
            if layer.a13_scale is None or layer.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")

            scales_agree = (all_close_1d(layer.a13_scale)
                            and all_close_1d(layer.a2_scale))
            if not scales_agree:
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")

            # Take the max of all the scales in case they differ.
            layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
                                                 requires_grad=False)
            layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
                                                requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            for shard_id in range(2):
                # Rows of this expert's w13 that belong to the shard.
                first_row = shard_id * shard_size
                rows = slice(first_row, first_row + shard_size)

                dequantized = per_tensor_dequantize(
                    layer.w13_weight[expert_id][rows, :],
                    layer.w13_scale[expert_id][shard_id])
                requantized, _ = ops.scaled_fp8_quant(
                    dequantized, max_w13_scales[expert_id])
                layer.w13_weight[expert_id][rows, :] = requantized

        layer.w13_scale = torch.nn.Parameter(max_w13_scales,
                                             requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:
        """Run the fused MoE kernel in fp8 mode with this layer's scales."""
        moe_output = fused_moe(x,
                               layer.w13_weight,
                               layer.w2_weight,
                               router_logits,
                               top_k,
                               renormalize=renormalize,
                               inplace=True,
                               use_fp8=True,
                               w1_scale=layer.w13_scale,
                               w2_scale=layer.w2_scale,
                               a1_scale=layer.a13_scale,
                               a2_scale=layer.a2_scale,
                               use_grouped_topk=use_grouped_topk,
                               num_expert_group=num_expert_group,
                               topk_group=topk_group)
        return moe_output
class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints."""

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka k_scale and v_scale) for an attention layer.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        # Initialize the KV cache scales to -1.0, which is an invalid value.
        # If the k/v_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        """Always raise; this method only exists to load scales.

        Raises:
            RuntimeError: Unconditionally.
        """
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        """Turn the loaded k/v scale parameters into plain float attributes.

        When ``layer.kv_cache_dtype`` is not "auto", ``layer._k_scale`` and
        ``layer._v_scale`` are set to Python floats chosen as follows:

        * both loaded scales positive: use them as-is;
        * both still negative (nothing was loaded): default to 1.0;
        * otherwise: duplicate the larger of the two into both.

        The temporary ``k_scale`` / ``v_scale`` parameters are deleted from
        the layer in every case.

        Raises:
            ValueError: If a resolved scale is not a single float, i.e. the
                checkpoint scale is not per-tensor.
        """
        # With an "auto" kv-cache dtype the checkpoint scales are ignored
        # and nothing is assigned here.
        # NOTE(review): presumably the attention layer already defaults its
        # _k_scale/_v_scale to 1.0 in that case - confirm.
        if layer.kv_cache_dtype != "auto":
            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                # We prefer to use separate k_scale and v_scale if present
                k_scale = layer.k_scale.to("cpu").tolist()
                v_scale = layer.v_scale.to("cpu").tolist()
            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
                # If no scales were loaded (both scales are invalid negative
                # values), use the default value of 1.0. These must be
                # plain floats: the per-tensor check below rejects anything
                # else, so a Parameter default would always raise there.
                k_scale = 1.0
                v_scale = 1.0
            else:
                # If we find a single kv_scale in the checkpoint, we remap
                # kv_scale to k_scale during weight loading, and duplicate
                # k_scale to v_scale here
                assert layer.k_scale > 0.0
                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                k_scale = scale_to_duplicate.to("cpu").tolist()
                v_scale = scale_to_duplicate.to("cpu").tolist()

            # tolist() on a multi-element tensor yields a list, not a float.
            if not isinstance(k_scale, float) or not isinstance(
                    v_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")

            # These are used in the final Attention.forward()
            layer._k_scale = k_scale
            layer._v_scale = v_scale
            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
                    and "e5m2" not in layer.kv_cache_dtype):
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                    "may cause accuracy issues. Please make sure k/v_scale "
                    "scaling factors are available in the fp8 checkpoint.")

        del layer.k_scale
        del layer.v_scale
|