
quants: add GPTQ and FBGEMM to AphroditeParameters (#987)

AlpinDale 2 months ago
parent commit 83af2524f3

+ 25 - 14
aphrodite/modeling/layers/linear.py

@@ -17,17 +17,27 @@ from aphrodite.distributed import (divide,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                           PackedAphroditeParameter,
-                                          PerTensorScaleParameter)
+                                          PackedColumnParameter,
+                                          PerTensorScaleParameter,
+                                          RowAphroditeParameter)
 # yapf: enable
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
-    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
-    "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
-    "Fp8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod",
-    "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod"
+    "AWQLinearMethod",
+    "AWQMarlinLinearMethod",
+    "CompressedTensorsLinearMethod",
+    "FBGEMMFp8LinearMethod",
+    "Fp8LinearMethod",
+    "GPTQLinearMethod",
+    "GPTQMarlin24LinearMethod",
+    "GPTQMarlinLinearMethod",
+    "HQQMarlinMethod",
+    "MarlinLinearMethod",
+    "QQQLinearMethod",
+    "TPUInt8LinearMethod",
 ]
 
 
@@ -601,8 +611,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
-            if isinstance(param, PackedAphroditeParameter
-                          ) and param.packed_dim == param.output_dim:
+            if isinstance(param, (PackedColumnParameter,
+                                  PackedAphroditeParameter,
+                                  )) and param.packed_dim == param.output_dim:
                 shard_size, shard_offset = \
                     param.adjust_shard_indexes_for_packing(
                     shard_size=shard_size, shard_offset=shard_offset)
@@ -621,7 +632,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                 shard_id=0)
                 return
-            elif type(param) is BaseAphroditeParameter:
+            elif type(param) in (RowAphroditeParameter, BaseAphroditeParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -760,8 +771,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
-            if isinstance(param, PackedAphroditeParameter
-                          ) and param.packed_dim == param.output_dim:
+            if isinstance(param, (PackedColumnParameter,
+                                  PackedAphroditeParameter,
+                                  )) and param.packed_dim == param.output_dim:
                 shard_size, shard_offset = \
                     param.adjust_shard_indexes_for_packing(
                     shard_size=shard_size, shard_offset=shard_offset)
@@ -777,11 +789,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                          loaded_shard_id: Optional[str] = None):
         if loaded_shard_id is None:  # special case for certain models
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight,
-                                                shard_id=0)
+                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param) is BaseAphroditeParameter:
-                param.load_merged_column_weight(loaded_weight=loaded_weight)
+            elif type(param) in (RowAphroditeParameter, BaseAphroditeParameter):
+                param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

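Note on the packed-parameter branch above: when a weight is packed along the dim being sharded, the shard offset and size must shrink by the pack factor before slicing the loaded tensor. A minimal sketch of that adjustment (assumed behaviour of adjust_shard_indexes_for_packing, ignoring the Marlin tile-size case):

from fractions import Fraction

def adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor):
    # Assumed behaviour: a packed tensor stores `packed_factor` logical
    # elements per physical element, so shard bounds shrink by that factor.
    return shard_size // packed_factor, shard_offset // packed_factor

# e.g. 4-bit weights packed into int32: 8 logical values per stored word.
print(adjust_shard_indexes_for_packing(4096, 8192, Fraction(32, 4)))
# -> (512, 1024)

The same floor division is what lets a Fraction-valued pack factor (see the parameter.py change below) work for odd bit widths.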
+ 6 - 3
aphrodite/modeling/layers/vocab_parallel_embedding.py

@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
 from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
+from aphrodite.modeling.parameter import BaseAphroditeParameter
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
@@ -368,10 +369,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         # If param packed on the same dim we are sharding on, then
         # need to adjust offsets of loaded weight by pack_factor.
         if packed_dim is not None and packed_dim == output_dim:
+            packed_factor = param.packed_factor if isinstance(
+                param, BaseAphroditeParameter) else param.pack_factor
             assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
-                                                       param.pack_factor)
-            start_idx = start_idx // param.pack_factor
-            shard_size = shard_size // param.pack_factor
+                                                       packed_factor)
+            start_idx = start_idx // packed_factor
+            shard_size = shard_size // packed_factor
         else:
             assert loaded_weight.shape[output_dim] == self.org_vocab_size
 

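The packed_factor/pack_factor fallback above keeps legacy tensors (tagged via set_weight_attrs with pack_factor) working alongside the new parameter classes (which expose packed_factor). Toy numbers for the divisions, assuming a 4-bit int32-packed vocab dimension:

from fractions import Fraction

# Toy numbers: a vocab of 32000 sharded over 4 ranks, packed 8 logical rows
# per stored int32 row. Mirrors the divisions in the hunk above.
org_vocab_size, tp_size, rank = 32000, 4, 1
packed_factor = Fraction(32, 4)

shard_size = org_vocab_size // tp_size          # 8000 logical rows per rank
start_idx = rank * shard_size                   # 8000

packed_vocab = org_vocab_size // packed_factor  # 4000 stored rows
print(start_idx // packed_factor, shard_size // packed_factor)  # 1000 1000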
+ 3 - 2
aphrodite/modeling/parameter.py

@@ -1,3 +1,4 @@
+from fractions import Fraction
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnAphroditeParameter):
     """
 
     def __init__(self,
-                 packed_factor: int,
+                 packed_factor: Union[int, Fraction],
                  packed_dim: int,
                  marlin_tile_size: Optional[int] = None,
                  **kwargs):
@@ -298,7 +299,7 @@ class PackedAphroditeParameter(ModelWeightParameter):
     """
 
     def __init__(self,
-                 packed_factor: int,
+                 packed_factor: Union[int, Fraction],
                  packed_dim: int,
                  marlin_tile_size: Optional[int] = None,
                  **kwargs):

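Widening packed_factor to Union[int, Fraction] matters because GPTQ's pack factor is 32 / weight_bits, which is only an integer for 2-, 4- and 8-bit weights. A small illustration (the Fraction-based pack factor matches how the GPTQ config is typically defined; treat the exact expression as an assumption here):

from fractions import Fraction

# 32-bit storage words hold 32 / weight_bits quantized values.
for bits in (2, 3, 4, 8):
    pack_factor = Fraction(32, bits)
    # Fraction floordiv yields an int, so packed sizes stay integral.
    print(bits, pack_factor, 1024 // pack_factor)  # packed size of a 1024-wide dim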
+ 20 - 14
aphrodite/quantization/fbgemm_fp8.py

@@ -6,7 +6,8 @@ from torch.nn.parameter import Parameter
 
 from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
-from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
+                                          ModelWeightParameter)
 from aphrodite.platforms import current_platform
 from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
@@ -14,8 +15,7 @@ from aphrodite.quantization.fp8 import cutlass_fp8_supported
 from aphrodite.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from aphrodite.quantization.utils.quant_utils import is_layer_skipped
-from aphrodite.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, create_per_channel_scale_param)
+from aphrodite.quantization.utils.w8a8_utils import apply_fp8_linear
 
 
 class FBGEMMFp8Config(QuantizationConfig):
@@ -81,6 +81,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        weight_loader = extra_weight_attrs.get("weight_loader")
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
 
@@ -91,20 +92,21 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         layer.orig_dtype = params_dtype
 
         # WEIGHT
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=torch.float8_e4m3fn),
-                           requires_grad=False)
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=torch.float8_e4m3fn),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            "input_dim": 1,
-            "output_dim": 0,
-            **extra_weight_attrs,
-        })
 
         # WEIGHT SCALE
-        weight_scale = create_per_channel_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+        weight_scale = ChannelQuantScaleParameter(data=torch.empty(
+            (sum(output_partition_sizes), 1), dtype=torch.float32),
+                                                  output_dim=0,
+                                                  weight_loader=weight_loader)
+        weight_scale[:] = torch.finfo(torch.float32).min
         layer.register_parameter("weight_scale", weight_scale)
 
         # INPUT SCALE UPPER BOUND
@@ -114,6 +116,10 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         layer.input_scale_ub = input_scale_ub
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        # required by torch.compile
+        layer.weight_scale = Parameter(layer.weight_scale.data,
+                                       requires_grad=False)
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)
 

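The FBGEMM weight scale is now a ChannelQuantScaleParameter: one float32 scale per output row, shape (sum(output_partition_sizes), 1), initialised to float32 min as a "not loaded yet" sentinel. After loading, both tensors are re-wrapped as plain nn.Parameter, presumably so torch.compile only ever traces ordinary tensors. A sketch of the re-wrap using a stand-in subclass (not the real class):

import torch
from torch.nn import Parameter

class ChannelQuantScaleParameter(Parameter):  # stand-in for the real class
    pass

out_rows = 4096 + 4096 + 1024  # sum(output_partition_sizes)
scale = ChannelQuantScaleParameter(
    torch.full((out_rows, 1), torch.finfo(torch.float32).min),
    requires_grad=False)

# After loading, re-wrap as a plain Parameter; .data is shared, no copy made.
plain = Parameter(scale.data, requires_grad=False)
print(type(plain) is Parameter, plain.shape)  # True torch.Size([9216, 1])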
+ 55 - 45
aphrodite/quantization/gptq.py

@@ -9,7 +9,11 @@ from torch.nn.parameter import Parameter
 from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
-from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
+                                          GroupQuantScaleParameter,
+                                          PackedAphroditeParameter,
+                                          PackedColumnParameter,
+                                          RowAphroditeParameter)
 from aphrodite.quantization.base_config import QuantizationConfig
 
 
@@ -107,6 +111,7 @@ class GPTQLinearMethod(LinearMethodBase):
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs.get("weight_loader")
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -137,73 +142,78 @@ class GPTQLinearMethod(LinearMethodBase):
                 scale_and_zero_size = input_size_per_partition // group_size
                 scale_and_zero_input_dim = 0
 
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedAphroditeParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 0,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        g_idx = Parameter(
-            torch.tensor(
-                [
-                    i // self.quant_config.group_size
-                    for i in range(input_size_per_partition)
-                ],
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        # Ignore warning from fused linear layers such as QKVParallelLinear.
-        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
-        qzeros = Parameter(
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+        g_idx = RowAphroditeParameter(data=torch.tensor(
+            [
+                i // self.quant_config.group_size
+                for i in range(input_size_per_partition)
+            ],
+            dtype=torch.int32,
+        ),
+                                 input_dim=0,
+                                 weight_loader=weight_loader)
+        qzeros_args = {
+            "data":
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": scale_and_zero_input_dim,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        scales = Parameter(
+            "weight_loader":
+            weight_loader
+        }
+        weight_scale_args = {
+            "data":
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition,
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": scale_and_zero_input_dim,
-            "output_dim": 1,
-        })
+            "weight_loader":
+            weight_loader
+        }
+        if scale_and_zero_input_dim is None:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                input_dim=0,
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args)
 
         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("g_idx", g_idx)
-        set_weight_attrs(g_idx, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
 
         layer.exllama_state = exllama_state
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # for torch.compile
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = Parameter(layer.scales.data, requires_grad=False)
+        layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
         if layer.exllama_state == ExllamaState.UNINITIALIZED:

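For orientation, the tensor shapes allocated above are unchanged from the old set_weight_attrs version; only the wrapper classes differ (scales/qzeros switch between ChannelQuantScaleParameter/GroupQuantScaleParameter and PackedColumnParameter with or without an input_dim, depending on whether they are sharded along the input). A worked shape example with illustrative 4-bit, group_size=128 numbers:

from fractions import Fraction

# Illustrative 4-bit layer, 4096 -> 11008, group_size 128 (toy numbers;
# shapes follow the torch.empty calls in the hunk above).
in_f, out_f, group = 4096, 11008, 128
pack = Fraction(32, 4)                      # 8 int4 values per int32 word

qweight = (in_f // pack, out_f)             # (512, 11008)
g_idx = (in_f,)                             # (4096,) group index per input row
qzeros = (in_f // group, out_f // pack)     # (32, 1376)
scales = (in_f // group, out_f)             # (32, 11008)
print(qweight, g_idx, qzeros, scales)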
+ 0 - 27
aphrodite/quantization/utils/w8a8_utils.py

@@ -1,11 +1,9 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
-from torch.nn import Parameter
 
 from aphrodite import _custom_ops as ops
 from aphrodite.common.utils import is_hip
-from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.platforms import current_platform
 
 # scaled_mm in pytorch on rocm has a bug that requires always
@@ -37,31 +35,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
 
 
-def create_per_tensor_scale_param(
-    output_partition_sizes: List[int],
-    **extra_weight_attrs,
-) -> Parameter:
-    scale = Parameter(torch.empty(len(output_partition_sizes),
-                                  dtype=torch.float32),
-                      requires_grad=False)
-    scale[:] = torch.finfo(torch.float32).min
-    set_weight_attrs(scale, {
-        "needs_scalar_to_array": True,
-        **extra_weight_attrs
-    })
-    return scale
-
-
-def create_per_channel_scale_param(output_partition_sizes: List[int],
-                                   **extra_weight_attrs) -> Parameter:
-    scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
-                                  dtype=torch.float32),
-                      requires_grad=False)
-    scale[:] = torch.finfo(torch.float32).min
-    set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
-    return scale
-
-
 def convert_to_channelwise(
         weight_scale: torch.Tensor,
         logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
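With create_per_tensor_scale_param and create_per_channel_scale_param gone, callers build the scale parameters directly, as the FBGEMM hunk above does. A hedged sketch of the equivalent constructions (the PerTensorScaleParameter signature is assumed from its use elsewhere in this PR; the weight_loader here is a stand-in):

import torch
from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
                                          PerTensorScaleParameter)

def noop_loader(param, loaded_weight, *args, **kwargs):  # stand-in loader
    param.data.copy_(loaded_weight)

output_partition_sizes = [4096, 4096, 1024]  # e.g. fused QKV partitions

# Per-channel: one scale per output row, replacing create_per_channel_scale_param.
channel_scale = ChannelQuantScaleParameter(
    data=torch.full((sum(output_partition_sizes), 1),
                    torch.finfo(torch.float32).min, dtype=torch.float32),
    output_dim=0,
    weight_loader=noop_loader)

# Per-tensor: one scale per logical shard, replacing create_per_tensor_scale_param.
tensor_scale = PerTensorScaleParameter(
    data=torch.full((len(output_partition_sizes), ),
                    torch.finfo(torch.float32).min, dtype=torch.float32),
    weight_loader=noop_loader)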