
static and dynamic fp8

AlpinDale 8 months ago
parent
commit
7d23892501
2 changed files with 208 additions and 28 deletions
  1. aphrodite/modeling/layers/linear.py (+31, -0)
  2. aphrodite/quantization/fp8.py (+177, -28)

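At a glance, the commit does two things: the tensor-parallel weight loaders in
linear.py learn to route FP8 scale parameters through a per-parameter
fp8_scales_shard_indexer callable, and fp8.py gains a scaled_fp8_quant helper
that dispatches between static quantization (scale loaded from the checkpoint)
and dynamic quantization (scale computed from the tensor at runtime). The real
kernels live in aphrodite._quant_C; the pure-PyTorch stand-in below is only a
sketch of that dispatch, and it assumes the dynamic path derives the scale as
amax / fp8_max, which the CUDA kernel is not guaranteed to match bit-for-bit.

    from typing import Optional, Tuple

    import torch


    def scaled_fp8_quant_sketch(
            x: torch.Tensor,
            scale: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pure-PyTorch sketch of the static/dynamic dispatch in fp8.py."""
        finfo = torch.finfo(torch.float8_e4m3fn)
        if scale is None:
            # Dynamic path: derive a per-tensor scale from the input itself.
            # (Assumed formulation: amax / fp8_max; the kernel may differ.)
            scale = x.abs().max().to(torch.float32) / finfo.max
        # Static path: reuse the provided (checkpoint) scale unchanged.
        q = (x / scale).clamp(min=finfo.min,
                              max=finfo.max).to(torch.float8_e4m3fn)
        return q, scale


    # Dynamic: the scale comes back computed from x. Static: pass a known one.
    x = torch.randn(4, 8, dtype=torch.float16)
    q_dyn, s_dyn = scaled_fp8_quant_sketch(x)
    q_sta, s_sta = scaled_fp8_quant_sketch(x, torch.tensor(0.05))
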
+ 31 - 0
aphrodite/modeling/layers/linear.py

@@ -236,6 +236,9 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
         param_data = param.data
@@ -244,6 +247,11 @@ class ColumnParallelLinear(LinearBase):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
+                                                                 loaded_weight,
+                                                                 shard_id=0)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -306,7 +314,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:
@@ -364,6 +376,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_size = loaded_weight.shape[0]
             shard_offset = loaded_shard_id * shard_size
             param_data = param_data.narrow(0, shard_offset, shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
@@ -445,6 +461,9 @@ class QKVParallelLinear(ColumnParallelLinear):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
         is_metadata = getattr(param, "is_metadata", False)
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
 
         if loaded_shard_id is None:
             # Loaded weight is already packed.
@@ -518,6 +537,10 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_index = ["q", "k", "v"].index(loaded_shard_id)
             param_data = param_data.narrow(0, shard_index * shard_size,
                                            shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
@@ -598,6 +621,9 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
         param_data = param.data
@@ -606,6 +632,11 @@ class RowParallelLinear(LinearBase):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
+                                                                 loaded_weight,
+                                                                 shard_id=0)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
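The loaders above only look for an fp8_scales_shard_indexer attribute on the
parameter; fp8.py (below) attaches it to the scale parameters it creates, and
the callable's job is to pick which slot of the merged per-logical-weight
scale tensor a loaded scalar belongs to. A simplified, self-contained
illustration of that contract, mirroring Fp8LinearMethod.scales_shard_indexer
from the second file:

    import torch
    from torch.nn.parameter import Parameter


    def scales_shard_indexer(param, loaded_weight, shard_id):
        # "q"/"k"/"v" shard ids map to slots 0/1/2; integer ids pass through.
        qkv_idxs = {"q": 0, "k": 1, "v": 2}
        if isinstance(shard_id, str):
            shard_id = qkv_idxs[shard_id]
        return param[shard_id], loaded_weight


    # One merged scale parameter with a slot per logical shard (here q, k, v).
    weight_scale = Parameter(torch.zeros(3, dtype=torch.float32),
                             requires_grad=False)
    loaded = torch.tensor(0.02)  # scalar scale read from the checkpoint
    param_data, loaded = scales_shard_indexer(weight_scale.data, loaded, "k")
    assert param_data.shape == loaded.shape  # same check the loaders make
    param_data.copy_(loaded)                 # writes into weight_scale[1]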

+ 177 - 28
aphrodite/quantization/fp8.py

@@ -1,28 +1,51 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
 from contextlib import suppress
-from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
+from loguru import logger
 
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.quantization.base_config import (QuantizationConfig)
 from aphrodite.modeling.utils import set_weight_attrs
-from aphrodite.quantization.base_config import (QuantizationConfig,
-                                                QuantizeMethodBase)
 
 HAS_QUANTS = False
 with suppress(ImportError):
     from aphrodite._quant_C import quant_ops as ops
     HAS_QUANTS = True
 
+ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+
+def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        ops.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        ops.static_scaled_fp8_quant(output, input, scale)
+    return output, scale
+
 
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
     def __init__(
         self,
+        is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
     ) -> None:
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning("Detected fp8 checkpoint. Please note that the "
+                           "format is experimental and subject to change.")
+        if activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(
+                f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
 
     @classmethod
@@ -43,11 +66,14 @@ class Fp8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
-        return cls(activation_scheme)
+        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
+                   activation_scheme=activation_scheme)
 
     def get_quant_method(
-            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+            self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
         if isinstance(layer, LinearBase):
             return Fp8LinearMethod(self)
         return None
@@ -58,9 +84,11 @@ class Fp8Config(QuantizationConfig):
 
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
-    We now support common FP16/BF16 model checkpoints ONLY. The weight
-    scaling factor will be initialized after the model weights are loaded.
-
+    Supports loading FP8 checkpoints with a static weight scale and a
+    dynamic or static activation scale.
+    Also supports loading FP16/BF16 model checkpoints, which are quantized
+    at load time with dynamic activation scaling. The weight scaling factor
+    is initialized after the model weights are loaded.
     Limitations:
     1. Only support per-tensor quantization due to torch._scaled_mm support.
     2. Only support float8_e4m3fn data type due to the limitation of
@@ -75,6 +103,24 @@ class Fp8LinearMethod(LinearMethodBase):
             raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
+    def _create_scale_param(
+        self,
+        scale_name: str,
+        layer: torch.nn.Module,
+        output_partition_sizes: List[int],
+        **extra_weight_attrs,
+    ) -> None:
+        scale = Parameter(torch.empty(len(output_partition_sizes),
+                                      dtype=torch.float32),
+                          requires_grad=False)
+        layer.register_parameter(scale_name, scale)
+        set_weight_attrs(
+            scale, {
+                **extra_weight_attrs,
+                "fp8_scales_shard_indexer":
+                self.scales_shard_indexer,
+            })
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -85,46 +131,149 @@ class Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+
+        layer.process_after_load = True
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight_dtype = (torch.float8_e4m3fn
+                        if self.quant_config.is_checkpoint_fp8_serialized else
+                        params_dtype)
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
-                                       dtype=params_dtype),
+                                       dtype=weight_dtype),
                            requires_grad=False)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        set_weight_attrs(weight, extra_weight_attrs)
+        set_weight_attrs(weight, {
+            **extra_weight_attrs,
+            "input_dim": 1,
+            "output_dim": 0,
+        })
 
-        w_scale = Parameter(
-            torch.empty(1, dtype=torch.float32),
-            requires_grad=False,
-        )
-        layer.register_parameter("weight_scaling_factor", w_scale)
+        # If checkpoint is serialized fp8, load them.
+        # Otherwise, wait until process_weights_after_loading.
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALE
+            self._create_scale_param(
+                scale_name="weight_scale",
+                layer=layer,
+                output_partition_sizes=output_partition_sizes,
+                **extra_weight_attrs)
+
+            # ACTIVATION SCALE
+            if self.quant_config.activation_scheme == "static":
+                self._create_scale_param(
+                    scale_name="act_scale",
+                    layer=layer,
+                    output_partition_sizes=output_partition_sizes,
+                    **extra_weight_attrs)
+
+    def scales_shard_indexer(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+
+        if isinstance(shard_id, int):
+            pass
+        elif isinstance(shard_id, str):
+            if shard_id not in qkv_idxs:
+                raise ValueError(f"Unknown shard_id: {shard_id}")
+            shard_id = qkv_idxs[shard_id]
+        else:
+            raise ValueError(f"Shard id must be int or str but got {type(shard_id)}")
+
+        return param[shard_id], loaded_weight
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        # Although the linear_method is propagated to all layers,
-        # only linear layers invoke "create_weights". So we check
-        # whether "weight_scaling_facor" is registered to determine
-        # whether the layer is a linear layer that requires quantization.
-        if not hasattr(layer, "weight_scaling_factor"):
+        if (not hasattr(layer, "process_after_load")
+                or not layer.process_after_load):
+            return
+
+        # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            qweight, weight_scale = scaled_fp8_quant(layer.weight, scale=None)
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.logical_widths = None
+            layer.act_scale = None
             return
 
-        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
-        # torch._scaled_mm requires column-major in the second
-        # input (weight), so we transpose the quantized weight.
-        layer.weight = Parameter(qweight.t(), requires_grad=False)
-        layer.weight_scaling_factor.data.copy_(weight_scale)
+        # If checkpoint is fp8, requantize the separately quantized logical
+        # weights into a single fp8 weight with a single weight scale.
+        else:
+            # WEIGHT_SCALE / WEIGHT
+            #   Loop over logical weights, requantizing with single scale.
+            max_w_scale = layer.weight_scale.max()
+            start = 0
+            for idx, logical_width in enumerate(layer.logical_widths):
+                end = start + logical_width
+                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
+                                                  layer.weight_scale[idx])
+
+                layer.weight[start:end, :] = per_tensor_quantize(
+                    weight_dq, max_w_scale)
+                start = end
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+            # WEIGHT
+            #   Transpose weight for passing to torch._scaled_mm
+            weight = layer.weight
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
+            # ACT_SCALE
+            #   Dynamic: set to None (required input to scaled_fp8_quant).
+            #   Static:  set to max of the act_scales (since they are equal).
+            if self.quant_config.activation_scheme == "dynamic":
+                layer.act_scale = None
+            elif self.quant_config.activation_scheme == "static":
+                if not all_close_1d(layer.act_scale):
+                    raise ValueError(
+                        "All the act_scales for the logical weights of a layer "
+                        f"must be equal. But got {layer.act_scale}")
+                layer.act_scale = Parameter(layer.act_scale.max(),
+                                            requires_grad=False)
+            else:
+                raise ValueError(
+                    f"Unknown scheme {self.quant_config.activation_scheme}")
 
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qinput, x_scale = ops.scaled_fp8_quant(x)
+        # scaled_fp8_quant supports both dynamic and static quant.
+        #   If dynamic, layer.act_scale is None and x_scale computed from x.
+        #   If static,  layer.act_scale is scalar and x_scale set to act_scale.
+        qinput, x_scale = scaled_fp8_quant(x, layer.act_scale)
+
+        # Fused GEMM_DQ
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
             out_dtype=x.dtype,
             scale_a=x_scale,
-            scale_b=layer.weight_scaling_factor,
+            scale_b=layer.weight_scale,
             bias=bias,
         )
+
         return output
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
+
+
+def per_tensor_quantize(tensor: torch.Tensor,
+                        inv_scale: float) -> torch.Tensor:
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+    return qweight.to(torch.float8_e4m3fn)
+
+
+def per_tensor_dequantize(tensor: torch.Tensor,
+                          inv_scale: float) -> torch.Tensor:
+    fake_qweight = tensor.to(torch.float16)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
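

For reference, a short stand-alone sketch of the requantization that
process_weights_after_loading performs for serialized FP8 checkpoints: each
logical weight (e.g. the q/k/v or gate/up shards fused into one linear layer)
arrives with its own scale, and the weights are dequantized and requantized
against the shared maximum so a single weight_scale can be handed to
torch._scaled_mm. The helpers below are copied from the diff; the shapes and
scale values are made up for illustration.

    import torch


    def per_tensor_quantize(tensor, inv_scale):
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (tensor / inv_scale).clamp(
            min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)


    def per_tensor_dequantize(tensor, inv_scale):
        return tensor.to(torch.float16) * inv_scale


    # Two logical weights quantized separately, each with its own scale.
    scales = torch.tensor([0.01, 0.03])
    w0 = per_tensor_quantize(torch.randn(4, 8, dtype=torch.float16), scales[0])
    w1 = per_tensor_quantize(torch.randn(4, 8, dtype=torch.float16), scales[1])

    # Requantize both against the shared max scale; afterwards the layer keeps
    # only max_w_scale as its single weight_scale for the fused GEMM.
    max_w_scale = scales.max()
    w0_req = per_tensor_quantize(per_tensor_dequantize(w0, scales[0]),
                                 max_w_scale)
    w1_req = per_tensor_quantize(per_tensor_dequantize(w1, scales[1]),
                                 max_w_scale)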