
quants: add support for NVIDIA's ModelOpt checkpoints (#1013)

AlpinDale 2 months ago
parent
commit
dcb36de9c4

+ 9 - 8
aphrodite/common/config.py

@@ -54,21 +54,22 @@ _PP_SUPPORTED_MODELS = [
 ]
 
 _OPTIMIZED_QUANTS = [
+    "awq_marlin",
+    "compressed-tensors",
+    "compressed_tensors",
+    "experts_int8", 
+    "fbgemm_fp8",
     "fp2",
-    "fp3",
+    "fp3", 
     "fp4",
     "fp5",
     "fp6",
     "fp7",
     "fp8",
-    "marlin",
-    "gptq_marlin_24",
     "gptq_marlin",
-    "awq_marlin",
-    "fbgemm_fp8",
-    "compressed-tensors",
-    "compressed_tensors",
-    "experts_int8",
+    "gptq_marlin_24",
+    "marlin",
+    "modelopt",
     "quant_llm",
 ]
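
For context, _OPTIMIZED_QUANTS is the allow-list consulted when Aphrodite decides whether to warn that a requested quantization method is not fully optimized; adding "modelopt" (and re-sorting the list) keeps ModelOpt checkpoints on the fast path. The sketch below is a hedged approximation of that membership check; the helper name and warning text are illustrative, not copied from config.py.

# Hedged sketch: how the _OPTIMIZED_QUANTS allow-list is presumably consulted.
# Only the membership check and the new "modelopt" entry come from this diff.
from typing import Optional
from loguru import logger

_OPTIMIZED_QUANTS = ["awq_marlin", "fp8", "gptq_marlin", "marlin", "modelopt"]  # truncated for illustration

def warn_if_unoptimized(quantization: Optional[str]) -> None:
    if quantization is not None and quantization not in _OPTIMIZED_QUANTS:
        logger.warning(
            f"{quantization} quantization is not fully optimized yet; "
            "inference may be slower than with an optimized method.")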
 

+ 1 - 0
aphrodite/modeling/layers/linear.py

@@ -36,6 +36,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "HQQMarlinMethod",
     "MarlinLinearMethod",
+    "ModelOptFp8LinearMethod",
     "QQQLinearMethod",
     "TPUInt8LinearMethod",
 ]
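
Registering ModelOptFp8LinearMethod in WEIGHT_LOADER_V2_SUPPORTED opts the new linear method into the v2 weight-loading path. Dispatch appears to be keyed on the class name, which is why the string must match the class defined in modelopt.py exactly; the helper below is an illustrative sketch, not the actual dispatch code in linear.py.

# Illustrative stand-in for the list defined in aphrodite/modeling/layers/linear.py.
WEIGHT_LOADER_V2_SUPPORTED = ["MarlinLinearMethod", "ModelOptFp8LinearMethod"]

def uses_weight_loader_v2(quant_method: object) -> bool:
    # Membership is decided by class name, so "ModelOptFp8LinearMethod" must
    # appear verbatim in the list for the new method to get the v2 loader.
    return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED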

+ 7 - 1
aphrodite/modeling/model_loader/weight_utils.py

@@ -184,9 +184,15 @@ def get_quant_config(model_config: ModelConfig,
     quant_config_file = quant_config_files[0]
     with open(quant_config_file, "r") as f:
         config = json.load(f)
-
         if model_config.quantization == "bitsandbytes":
             config["adapter_name_or_path"] = model_name_or_path
+        elif model_config.quantization == "modelopt":
+            if config["producer"]["name"] == "modelopt":
+                return quant_cls.from_config(config)
+            else:
+                raise ValueError(
+                    "Unsupported quantization config found for "
+                    f"{model_config.quantization} in {quant_config_file}.")
 
     return quant_cls.from_config(config)
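
The new branch keys off the producer name in hf_quant_config.json, and ModelOptFp8Config.from_config (added below in modelopt.py) later reads quantization.quant_algo from the same file. A minimal config shaped like what this code expects might look like the sketch below; only the two asserted fields are actually checked by this PR, and the version string and kv_cache_quant_algo entry are illustrative.

# Minimal illustrative shape of an hf_quant_config.json accepted by this branch.
example_hf_quant_config = {
    "producer": {"name": "modelopt", "version": "0.15.0"},   # version is illustrative
    "quantization": {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8"},
}

# The two checks this PR performs on the parsed config:
assert example_hf_quant_config["producer"]["name"] == "modelopt"
assert "FP8" in example_hf_quant_config["quantization"]["quant_algo"]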
 

+ 2 - 0
aphrodite/quantization/__init__.py

@@ -19,6 +19,7 @@ from aphrodite.quantization.gptq_marlin import GPTQMarlinConfig
 from aphrodite.quantization.gptq_marlin_24 import GPTQMarlin24Config
 from aphrodite.quantization.hqq_marlin import HQQMarlinConfig
 from aphrodite.quantization.marlin import MarlinConfig
+from aphrodite.quantization.modelopt import ModelOptFp8Config
 from aphrodite.quantization.neuron_quant import NeuronQuantConfig
 from aphrodite.quantization.qqq import QQQConfig
 from aphrodite.quantization.quip import QuipConfig
@@ -34,6 +35,7 @@ QUANTIZATION_METHODS = {
     "fp8": Fp8Config,
     "quant_llm": QuantLLMFPConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
+    "modelopt": ModelOptFp8Config,
     "gguf": GGUFConfig,
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
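
With ModelOptFp8Config registered under the "modelopt" key, get_quant_config in weight_utils.py can resolve the quantization name to the new config class. A short usage sketch, assuming aphrodite is installed; the inline config dict is a minimal illustrative stand-in for a real hf_quant_config.json.

# Usage sketch of the registry entry added above.
from aphrodite.quantization import QUANTIZATION_METHODS

quant_cls = QUANTIZATION_METHODS["modelopt"]                    # -> ModelOptFp8Config
cfg = quant_cls.from_config({"quantization": {"quant_algo": "FP8"}})
print(type(cfg).__name__, quant_cls.get_min_capability())      # ModelOptFp8Config 89 (SM 8.9, i.e. Ada or newer)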

+ 180 - 0
aphrodite/quantization/modelopt.py

@@ -0,0 +1,180 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from loguru import logger
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.parameter import (ModelWeightParameter,
+                                          PerTensorScaleParameter)
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
+from aphrodite.quantization.kv_cache import BaseKVCacheMethod
+from aphrodite.quantization.utils.w8a8_utils import (apply_fp8_linear,
+                                                     cutlass_fp8_supported,
+                                                     requantize_with_max_scale)
+
+ACTIVATION_SCHEMES = ["static"]
+
+
+class ModelOptFp8Config(QuantizationConfig):
+    """Config class for ModelOpt FP8."""
+
+    def __init__(
+        self,
+        is_checkpoint_fp8_serialized: bool = False,
+    ) -> None:
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected ModelOpt fp8 checkpoint. Please note that"
+                " the format is experimental and could change."
+            )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
+        quant_config = cls.get_from_keys(config, ["quantization"])
+        quant_method = quant_config["quant_algo"]
+        is_checkpoint_fp8_serialized = "FP8" in quant_method
+        if not is_checkpoint_fp8_serialized:
+            raise ValueError(
+                "ModelOpt currently only supports static FP8 "
+                "quantization in Aphrodite. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration."
+            )
+        return cls(is_checkpoint_fp8_serialized)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from aphrodite.attention.layer import (
+            Attention)  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            return ModelOptFp8LinearMethod(self)
+        elif isinstance(layer, Attention):
+            return ModelOptFp8KVCacheMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from FP8 checkpoints.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__(quant_config)
+
+
+class ModelOptFp8LinearMethod(LinearMethodBase):
+    """Linear method for Model Optimizer static quantization.
+    Supports loading FP8 checkpoints with static weight scale and
+    activation scale. Future support might be added for dynamic
+    scales.
+    Limitations:
+    1. Only supports per-tensor quantization due to torch._scaled_mm support.
+    2. Only supports the float8_e4m3fn datatype.
+    Args: quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALE
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            weight_scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("weight_scale", weight_scale)
+            # INPUT SCALE
+            scale = PerTensorScaleParameter(
+                data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            scale[:] = torch.finfo(torch.float32).min
+            layer.register_parameter("input_scale", scale)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        max_w_scale, weight = requantize_with_max_scale(
+            layer.weight, layer.weight_scale, layer.logical_widths
+        )
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+        layer.input_scale = Parameter(
+            layer.input_scale.max(), requires_grad=False
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+        )
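
For intuition about what ModelOptFp8LinearMethod loads: static per-tensor FP8 stores one float32 scale per weight tensor (and one per activation tensor), with requantize_with_max_scale collapsing the per-shard weight scales to a single maximum and apply_fp8_linear performing the scaled matmul. The standalone sketch below shows that per-tensor quantize/dequantize convention in plain PyTorch (torch >= 2.1); it is an illustration, not Aphrodite's kernel path.

import torch

def quantize_per_tensor_fp8(weight: torch.Tensor):
    # One scale for the whole tensor, chosen so the max magnitude maps to the
    # largest representable float8_e4m3fn value (448.0).
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max() / finfo.max
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale

w = torch.randn(128, 256, dtype=torch.float32)
qw, w_scale = quantize_per_tensor_fp8(w)
w_approx = qw.to(torch.float32) * w_scale        # dequantize for comparison
print("max abs error:", (w - w_approx).abs().max().item())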