
chore: add AphroditeParameter support for FP8 quant (#902)

AlpinDale, 2 months ago
commit afc9a28aa0
3 changed files with 49 additions and 16 deletions
  1. aphrodite/modeling/layers/linear.py (+15, -0)
  2. aphrodite/modeling/parameter.py (+8, -1)
  3. aphrodite/quantization/fp8.py (+26, -15)

aphrodite/modeling/layers/linear.py (+15, -0)

@@ -26,6 +26,7 @@ from aphrodite.quantization.base_config import (QuantizationConfig,
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
     "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
+    "Fp8LinearMethod",
 ]
 
 
@@ -359,6 +360,12 @@ class ColumnParallelLinear(LinearBase):
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
@@ -1081,8 +1088,16 @@ class RowParallelLinear(LinearBase):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
+
     def weight_loader_v2(self, param: BaseAphroditeParameter,
                          loaded_weight: torch.Tensor):
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
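
The special case added to both weight loaders above exists because AutoFP8-style checkpoints can serialize a per-tensor scale as a 0-dim (scalar) tensor, while the v2 loading path expects at least one dimension. A minimal sketch of the promotion, using only stock PyTorch (the scale value is made up):

import torch

# A per-tensor scale as it may come off disk: a 0-dim tensor.
loaded_weight = torch.tensor(0.0125)
print(loaded_weight.shape, loaded_weight.numel())  # torch.Size([]) 1

# Same check as in weight_loader_v2: promote the scalar to a 1-element vector.
if len(loaded_weight.shape) == 0:
    assert loaded_weight.numel() == 1
    loaded_weight = loaded_weight.reshape(1)

print(loaded_weight.shape)  # torch.Size([1])
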

aphrodite/modeling/parameter.py (+8, -1)

@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BaseAphroditeParameter):
         if isinstance(shard_id, int):
             return shard_id
 
+        # if not int, assume shard_id for qkv
+        # map to int and return
         assert isinstance(shard_id, str)
         assert shard_id in self.qkv_idxs
         return self.qkv_idxs[shard_id]
 
+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
     def load_merged_column_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)
 
@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BaseAphroditeParameter):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_column_parallel_weight(self, *args, **kwargs):
-        self._load_into_shard_id(*args, **kwargs)
+        super().load_row_parallel_weight(*args, **kwargs)
 
     def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                             shard_id: Union[str, int], **kwargs):
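
For context on the class touched here, the sketch below is a simplified, hypothetical stand-in (TinyPerTensorScale is not the real PerTensorScaleParameter) showing the same idea: one scale slot per logical shard of a fused layer, with string shard ids ("q", "k", "v") mapped to integer indices. For row- and column-parallel layers the scale is not sharded at all, which appears to be why the diff routes both through the base class's load_row_parallel_weight.

from typing import Union
import torch

class TinyPerTensorScale:
    """Illustrative only; the real class lives in aphrodite/modeling/parameter.py."""

    def __init__(self, num_shards: int):
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        # One scalar per logical shard, initialized to float32 min.
        self.data = torch.full((num_shards,), torch.finfo(torch.float32).min)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id
        # If not an int, assume a qkv shard id and map it to an index.
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    def load_into_shard(self, loaded_weight: torch.Tensor,
                        shard_id: Union[str, int]) -> None:
        idx = self._shard_id_as_int(shard_id)
        self.data[idx] = loaded_weight.reshape(())  # one scalar per shard

scales = TinyPerTensorScale(num_shards=3)
scales.load_into_shard(torch.tensor(0.02), "q")
scales.load_into_shard(torch.tensor(0.03), "k")
scales.load_into_shard(torch.tensor(0.05), "v")
print(scales.data)  # tensor([0.0200, 0.0300, 0.0500])
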

aphrodite/quantization/fp8.py (+26, -15)

@@ -11,6 +11,8 @@ from aphrodite.common.utils import is_hip, print_warning_once
 from aphrodite.modeling.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
+from aphrodite.modeling.parameter import (ModelWeightParameter,
+                                          PerTensorScaleParameter)
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.platforms import current_platform
 from aphrodite.quantization.base_config import (QuantizationConfig,
@@ -21,8 +23,7 @@ from aphrodite.quantization.utils.marlin_utils_fp8 import (
 from aphrodite.quantization.utils.quant_utils import is_layer_skipped
 from aphrodite.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    create_per_tensor_scale_param, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -136,6 +137,7 @@ class Fp8LinearMethod(LinearMethodBase):
     ):
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         layer.logical_widths = output_partition_sizes
 
@@ -147,34 +149,38 @@ class Fp8LinearMethod(LinearMethodBase):
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
                         params_dtype)
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=weight_dtype),
-                           requires_grad=False)
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=weight_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
         # If checkpoint is serialized fp8, load them.
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                  **extra_weight_attrs)
+            scale = PerTensorScaleParameter(data=torch.empty(
+                len(output_partition_sizes), dtype=torch.float32),
+                                            weight_loader=weight_loader)
+            scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
-                scale = create_per_tensor_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+                scale = PerTensorScaleParameter(data=torch.empty(
+                    len(output_partition_sizes), dtype=torch.float32),
+                                                weight_loader=weight_loader)
+                scale[:] = torch.finfo(torch.float32).min
                 layer.register_parameter("input_scale", scale)
             else:
                 layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
@@ -196,6 +202,11 @@ class Fp8LinearMethod(LinearMethodBase):
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                    requires_grad=False)
+            if self.quant_config.activation_scheme == "static":
+                layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+                                                       requires_grad=False)
             # If using marlin (w8a16), kernel uses channelwise weights,
             # so extend the weight scales to be channelwise.
             if self.use_marlin:
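
One way to read the process_weights_after_loading changes above: the parameter subclasses carry loading-time behavior (shard mapping, weight_loader hooks) that is no longer needed once the checkpoint has been read, so the layer's attributes are re-registered as plain torch.nn.Parameter objects built from the loaded data. A small sketch of that pattern, using a made-up LoaderScaleParam stand-in rather than the real Aphrodite parameter types:

import torch
from torch import nn

class LoaderScaleParam(nn.Parameter):
    """Stand-in for a loader-aware parameter subclass (illustrative only)."""

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
        # Copy the loaded values into the parameter's storage.
        self.data.copy_(loaded_weight)

layer = nn.Module()
scale = LoaderScaleParam(torch.empty(3, dtype=torch.float32),
                         requires_grad=False)
layer.register_parameter("weight_scale", scale)

# Weight loading happens through the load_*_weight hooks.
layer.weight_scale.load_row_parallel_weight(torch.tensor([0.02, 0.03, 0.05]))

# After loading, only the raw values matter, so the attribute is re-registered
# as a plain nn.Parameter, mirroring what process_weights_after_loading does.
layer.weight_scale = nn.Parameter(layer.weight_scale.data, requires_grad=False)
print(type(layer.weight_scale), layer.weight_scale)
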