
feat: fp8 quant

AlpinDale 10 months ago
parent
commit
85a865cc00

+ 8 - 1
aphrodite/modeling/hf_downloader.py

@@ -130,11 +130,18 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
             )
     else:
         hf_folder = model_name_or_path
+
+    possible_config_filenames = quant_cls.get_config_filenames()
+
+    # If the quantization config is not found, use the default config.
+    if not possible_config_filenames:
+        return quant_cls()
+
     config_files = glob.glob(os.path.join(hf_folder, "*.json"))
 
     quant_config_files = [
         f for f in config_files if any(
-            f.endswith(x) for x in quant_cls.get_config_filenames())
+            f.endswith(x) for x in possible_config_filenames)
     ]
     if len(quant_config_files) == 0:
         raise ValueError(
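
The early return above is what makes FP8 loadable at all: FP8Config.get_config_filenames() (added below) returns an empty list, so get_quant_config instantiates the default config instead of scanning the checkpoint folder for *.json files. A minimal, self-contained sketch of that control flow, using a hypothetical stand-in class rather than Aphrodite code:

    from typing import List


    class NoFileQuantConfig:
        """Hypothetical config class that, like FP8Config, declares no JSON files."""

        @classmethod
        def get_config_filenames(cls) -> List[str]:
            return []


    def resolve_quant_config(quant_cls):
        possible_config_filenames = quant_cls.get_config_filenames()
        # No declared filenames: fall back to the default config instead of
        # globbing hf_folder/*.json, which would raise for an fp8 checkpoint.
        if not possible_config_filenames:
            return quant_cls()
        raise NotImplementedError("would scan hf_folder/*.json here")


    print(resolve_quant_config(NoFileQuantConfig))  # a NoFileQuantConfig instance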

+ 6 - 0
aphrodite/modeling/layers/linear.py

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
 import torch
 import torch.nn.functional as F
 from loguru import logger
+from torch import nn
 from torch.nn.parameter import Parameter
 
 from aphrodite.distributed import (
@@ -138,6 +139,11 @@ class LinearMethodBase(ABC):
                           topk: int, renormalize: bool) -> torch.Tensor:
         """Apply the weights to the input tensor."""
         raise NotImplementedError
+
+
+    def process_weights_after_loading(self, layer: nn.Module) -> None:
+        """Process the weights after loading."""
+        pass
 
 
 class UnquantizedLinearMethod(LinearMethodBase):

+ 4 - 0
aphrodite/modeling/loader.py

@@ -147,4 +147,8 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
                         2,
                     ),
                 ))
+        for _, module in model.named_modules():
+            linear_method = getattr(module, "linear_method", None)
+            if linear_method is not None:
+                linear_method.process_weights_after_loading(module)
     return model.eval()
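
Taken together, the two hunks above add a no-op process_weights_after_loading hook to LinearMethodBase and have the loader call it on every module that carries a linear_method attribute once all checkpoint weights are materialized; keeping the base-class hook a no-op means existing quantization methods need no changes, while Fp8LinearMethod (below) uses it to quantize the loaded weights in place. A self-contained sketch of the dispatch pattern, with toy classes that are not Aphrodite code:

    from torch import nn


    class ToyLinearMethod:
        """Hypothetical linear method using the new post-load hook."""

        def process_weights_after_loading(self, layer: nn.Module) -> None:
            # Re-pack the already-loaded weight in place (here: just transpose).
            layer.weight = nn.Parameter(layer.weight.data.t().contiguous(),
                                        requires_grad=False)


    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
    model[0].linear_method = ToyLinearMethod()

    # Same loop as in get_model(): fire the hook wherever a linear_method exists.
    for _, module in model.named_modules():
        linear_method = getattr(module, "linear_method", None)
        if linear_method is not None:
            linear_method.process_weights_after_loading(module)

    print(model[0].weight.shape)  # torch.Size([8, 4]) after the in-place re-pack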

+ 2 - 0
aphrodite/quantization/__init__.py

@@ -10,6 +10,7 @@ from aphrodite.quantization.bitsandbytes import \
     BitsandBytesConfig
 from aphrodite.quantization.eetq import EETQConfig
 from aphrodite.quantization.exl2 import Exl2Config
+from aphrodite.quantization.fp8 import FP8Config
 from aphrodite.quantization.gguf import GGUFConfig
 from aphrodite.quantization.gptq import GPTQConfig
 from aphrodite.quantization.marlin import MarlinConfig
@@ -30,6 +31,7 @@ QUANTIZATION_METHODS = {
     "bnb": BitsandBytesConfig,
     "eetq": EETQConfig,
     "exl2": Exl2Config,
+    "fp8": FP8Config,
     "gguf": GGUFConfig,
     "gptq": GPTQConfig,
     "quip": QuipConfig,

+ 133 - 0
aphrodite/quantization/fp8.py

@@ -0,0 +1,133 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from aphrodite.modeling.layers.linear import (LinearMethodBase,
+                                              set_weight_attrs)
+from aphrodite.quantization.base_config import QuantizationConfig
+
+
+class FP8Config(QuantizationConfig):
+    """Config class for FP8."""
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "fp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
+        return cls()
+
+    def get_linear_method(self) -> "Fp8LinearMethod":
+        return Fp8LinearMethod(self)
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class Fp8LinearMethod(LinearMethodBase):
+    """Linear method for FP8.
+    Only common FP16/BF16 model checkpoints are supported for now. The weight
+    scaling factor is initialized after the model weights are loaded.
+    Limitations:
+    1. Only per-tensor quantization is supported (a torch._scaled_mm limit).
+    2. Only the float8_e4m3fn data type is supported, due to a limitation of
+       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: FP8Config):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        weight = Parameter(torch.empty(output_size_per_partition,
+                                       input_size_per_partition,
+                                       dtype=params_dtype),
+                           requires_grad=False)
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        set_weight_attrs(weight, extra_weight_attrs)
+
+        w_scale = Parameter(
+            torch.empty(1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("weight_scaling_factor", w_scale)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # Although the linear_method is propagated to all layers,
+        # only linear layers invoke "create_weights". So we check
+        # whether "weight_scaling_factor" is registered to determine
+        # whether the layer is a linear layer that requires quantization.
+        if not hasattr(layer, "weight_scaling_factor"):
+            return
+
+        qweight, weight_scale = per_tensor_quantize(layer.weight)
+        # torch._scaled_mm requires column-major in the second
+        # input (weight), so we transpose the quantized weight.
+        layer.weight = Parameter(qweight.t(), requires_grad=False)
+        layer.weight_scaling_factor.data.copy_(weight_scale)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qinput, x_scale = per_tensor_quantize(x)
+        output, _ = torch._scaled_mm(
+            qinput,
+            layer.weight,
+            out_dtype=x.dtype,
+            scale_a=x_scale,
+            scale_b=layer.weight_scaling_factor,
+            bias=bias,
+        )
+        return output
+
+
+def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    """Quantize a tensor using a per-tensor static scaling factor.
+    Args:
+        tensor: The input tensor.
+    """
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    # Calculate the scale as dtype max divided by absmax.
+    # Since .abs() creates a new tensor, we use aminmax to get
+    # the min and max first and then calculate the absmax.
+    min_val, max_val = tensor.aminmax()
+    amax = min_val.abs().max(max_val.abs())
+    scale = finfo.max / amax.clamp(min=1e-12)
+    # Scale and clamp the tensor to bring it into the representable
+    # range of the float8 data type, since the default cast does not
+    # saturate out-of-range values.
+    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
+    # Return both the float8 data and the inverse scale (a float32 tensor),
+    # since both are required as inputs to torch._scaled_mm.
+    qweight = qweight.to(torch.float8_e4m3fn)
+    scale = scale.float().reciprocal()
+    return qweight, scale
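
As a sanity check on per_tensor_quantize, multiplying the fp8 payload by the returned inverse scale should reconstruct the input up to e4m3 rounding error; apply_weights then hands exactly these pieces (the fp8 tensors plus both inverse scales) to torch._scaled_mm. A small roundtrip sketch, assuming a PyTorch build that ships torch.float8_e4m3fn and Aphrodite on the import path:

    import torch

    from aphrodite.quantization.fp8 import per_tensor_quantize

    w = torch.randn(256, 128)    # random weight matrix (fp32 keeps the demo CPU-friendly)
    qweight, inv_scale = per_tensor_quantize(w)

    print(qweight.dtype)     # torch.float8_e4m3fn
    print(inv_scale.dtype)   # torch.float32 (inverse of the applied scale)

    # Dequantize: fp8 payload * inverse scale approximates the original tensor.
    w_hat = qweight.to(torch.float32) * inv_scale
    print((w - w_hat).abs().max().item())    # small; bounded by e4m3 rounding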