from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.nn import Module
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import print_warning_once
from aphrodite.modeling.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                 fused_moe)
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)
from aphrodite.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from aphrodite.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
    cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)

# Activation quantization schemes accepted by Fp8Config.
ACTIVATION_SCHEMES = ["static", "dynamic"]


class Fp8Config(QuantizationConfig):
    """Config class for FP8.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already stores
            fp8 weights (and their scales). When False, weights are quantized
            after loading.
        activation_scheme: One of ``ACTIVATION_SCHEMES``.

    Raises:
        ValueError: If ``activation_scheme`` is not a supported scheme.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected fp8 checkpoint. Please note that the "
                           "format is experimental and subject to change.")
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes supported by this method."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability (major * 10 + minor)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return extra config filenames to look for (none for fp8)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        """Build an Fp8Config from a quantization config dict.

        Reads the ``quant_method`` and ``activation_scheme`` keys. The
        checkpoint is treated as fp8-serialized when ``quant_method``
        contains ``"fp8"``.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method matching the layer type, or None."""
        from aphrodite.attention.layer import \
            Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
        elif isinstance(layer, Attention):
            return Fp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for fp8)."""
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for fp8 checkpoints, scales) on layer.

        The weight is allocated as float8_e4m3fn when the checkpoint is
        fp8-serialized, otherwise in ``params_dtype``. ``input_size`` and
        ``output_size`` are unused.
        """
        del input_size, output_size
        output_size_per_partition = sum(output_partition_sizes)

        # Bookkeeping read later by process_weights_after_loading / apply.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=weight_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            scale = create_per_tensor_scale_param(output_partition_sizes,
                                                  **extra_weight_attrs)
            layer.register_parameter("weight_scale", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = create_per_tensor_scale_param(output_partition_sizes,
                                                      **extra_weight_attrs)
                layer.register_parameter("input_scale", scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize weights and scales once the checkpoint is loaded.

        Non-fp8 checkpoints are quantized here; fp8 checkpoints have their
        per-logical-weight scales collapsed into a single scale. In both
        cases the stored weight is replaced by its transpose. When Marlin is
        used, the layer is additionally prepared for the Marlin kernel and
        ``input_scale`` is removed.
        """
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                         scale=None)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

        # If checkpoint is fp8, requantize the separately quantized logical
        # weights into a single fp8 weight with a single weight scale.
        else:
            # Dequant -> Quant with max scale.
            max_w_scale, weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )

            # Update layer with new values.
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
            if self.quant_config.activation_scheme == "static":
                layer.input_scale = Parameter(layer.input_scale.max(),
                                              requires_grad=False)
            else:
                layer.input_scale = None

        if self.use_marlin:
            prepare_fp8_layer_for_marlin(layer)
            # Activations not quantized for marlin.
            del layer.input_scale

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fp8 linear layer on ``x`` (Marlin path or fp8 path)."""
        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=False)


class Fp8MoEMethod(FusedMoEMethodBase):
    """MoE method for FP8.
    Supports loading FP8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register expert weights and weight/activation scales on layer.

        Raises:
            ValueError: If the activation scheme is "static" but the
                checkpoint is not fp8-serialized.
        """
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                  2,
                                                  dtype=torch.float32),
                                       requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 dtype=torch.float32),
                                      requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)

        # If loading fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        #   process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_scale, extra_weight_attrs)
            set_weight_attrs(w2_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            a13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                      dtype=torch.float32),
                                           requires_grad=False)
            layer.register_parameter("a13_scale", a13_scale)
            set_weight_attrs(a13_scale, extra_weight_attrs)

            a2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
            layer.register_parameter("a2_scale", a2_scale)
            set_weight_attrs(a2_scale, extra_weight_attrs)
        else:
            layer.a13_scale = None
            layer.a2_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
        """Finalize expert weights and scales once the checkpoint is loaded.

        Non-fp8 checkpoints are quantized per expert. For fp8 checkpoints,
        activation scales are reduced to their maximum and the two w13 shard
        scales of each expert are merged into one by requantizing with the
        per-expert maximum.

        Raises:
            ValueError: If the scheme is "static" but an activation scale is
                None.
        """
        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(layer.w13_weight.data,
                                          dtype=torch.float8_e4m3fn)
            w2_weight = torch.empty_like(layer.w2_weight.data,
                                         dtype=torch.float8_e4m3fn)

            # Re-initialize w13_scale because we directly quantize
            # merged w13 weights and generate a single scaling factor.
            layer.w13_scale = torch.nn.Parameter(torch.ones(
                layer.num_experts,
                dtype=torch.float32,
                device=w13_weight.device),
                                                 requires_grad=False)
            for expert in range(layer.num_experts):
                w13_weight[expert, :, :], layer.w13_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], layer.w2_scale[
                    expert] = ops.scaled_fp8_quant(
                        layer.w2_weight.data[expert, :, :])
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            return

        # If checkpoint is fp8, we need to handle that the
        # MoE kernels require single activation scale and single weight
        # scale for w13 per expert.
        else:
            # Fp8 moe kernels require a single activation scale.
            # We take the max of all the scales in case they differ.
            if self.quant_config.activation_scheme == "static":
                if layer.a13_scale is None or layer.a2_scale is None:
                    raise ValueError(
                        "QuantConfig has static quantization, but found "
                        "activation scales are None.")
                if (not all_close_1d(layer.a13_scale)
                        or not all_close_1d(layer.a2_scale)):
                    print_warning_once(
                        "Found input_scales that are not equal for "
                        "fp8 MoE layer. Using the maximum across experts "
                        "for each layer. ")
                layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
                                                     requires_grad=False)
                layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
                                                    requires_grad=False)

            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_scale.max(dim=1).values
            for expert_id in range(layer.num_experts):
                start = 0
                # The two shards occupy consecutive row ranges of w13_weight.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_scale = torch.nn.Parameter(max_w13_scales,
                                                 requires_grad=False)
            return

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:
        """Run the fused fp8 MoE kernel using the layer's weights/scales."""
        return fused_moe(x,
                         layer.w13_weight,
                         layer.w2_weight,
                         router_logits,
                         top_k,
                         renormalize=renormalize,
                         inplace=True,
                         use_fp8=True,
                         w1_scale=layer.w13_scale,
                         w2_scale=layer.w2_scale,
                         a1_scale=layer.a13_scale,
                         a2_scale=layer.a2_scale,
                         use_grouped_topk=use_grouped_topk,
                         num_expert_group=num_expert_group,
                         topk_group=topk_group)


class Fp8KVCacheMethod(QuantizeMethodBase):
    """Supports loading kv-cache scaling factors from FP8 checkpoints.
    """

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """Create "weight" (aka k_scale and v_scale) for an attention layer.

        Args:
            layer: The layer that is using the QuantizeMethodBase factory.
        """
        # Initialize the KV cache scales to -1.0, which is an invalid value.
        # If the k/v_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
        layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        """Not supported; always raises RuntimeError."""
        raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert loaded k/v scales to floats stored as _k_scale/_v_scale.

        Only acts when ``layer.kv_cache_dtype`` is not "auto". The temporary
        ``k_scale`` / ``v_scale`` parameters are always removed from the
        layer afterwards.

        Raises:
            ValueError: If a resolved scale is not a single Python float,
                i.e. the loaded scale is not per-tensor.
        """
        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
        # regardless whether the kv-scale is available in the checkpoint.
        if layer.kv_cache_dtype != "auto":
            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                # We prefer to use separate k_scale and v_scale if present
                k_scale = layer.k_scale.to("cpu").tolist()
                v_scale = layer.v_scale.to("cpu").tolist()
            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
                # If no scales were loaded (both scales are invalid negative
                # values), use the default value of 1.0. These must be plain
                # floats (not Parameters) so they pass the per-tensor float
                # check below, like the .tolist() results of the other
                # branches.
                k_scale = 1.0
                v_scale = 1.0
            else:
                # If we find a single kv_scale in the checkpoint, we remap
                # kv_scale to k_scale during weight loading, and duplicate
                # k_scale to v_scale here
                assert layer.k_scale > 0.0
                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                k_scale = scale_to_duplicate.to("cpu").tolist()
                v_scale = scale_to_duplicate.to("cpu").tolist()

            if not isinstance(k_scale, float) or not isinstance(
                    v_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")

            # These are used in the final Attention.forward()
            layer._k_scale = k_scale
            layer._v_scale = v_scale
            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
                    and "e5m2" not in layer.kv_cache_dtype):
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                    "may cause accuracy issues. Please make sure k/v_scale "
                    "scaling factors are available in the fp8 checkpoint.")

        del layer.k_scale
        del layer.v_scale