Преглед изворни кода

feat: migrate awq and awq_marlin to AphroditeParameter (#702)

AlpinDale пре 6 месеци
родитељ
комит
edec2e9a9e
3 измењених фајлова са 70 додато и 81 уклоњено
  1. 2 1
      aphrodite/modeling/layers/linear.py
  2. 35 39
      aphrodite/quantization/awq.py
  3. 33 41
      aphrodite/quantization/awq_marlin.py

+ 2 - 1
aphrodite/modeling/layers/linear.py

@@ -24,7 +24,8 @@ from aphrodite.quantization.base_config import (QuantizationConfig,
                                                 QuantizeMethodBase)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
-    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod"
+    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
+    "AWQMarlinLinearMethod", "AWQLinearMethod",
 ]
 
 

+ 35 - 39
aphrodite/quantization/awq.py

@@ -1,11 +1,11 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-from torch.nn.parameter import Parameter
 
 from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
-from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.modeling.parameter import (GroupQuantScaleParameter,
+                                          PackedAphroditeParameter)
 from aphrodite.quantization.base_config import QuantizationConfig
 
 
@@ -98,55 +98,51 @@ class AWQLinearMethod(LinearMethodBase):
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
 
-        qweight = Parameter(
-            torch.empty(
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedAphroditeParameter(
+            data=torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        qzeros = Parameter(
-            torch.empty(
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        qzeros = PackedAphroditeParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        scales = Parameter(
-            torch.empty(
-                input_size_per_partition // self.quant_config.group_size,
-                output_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": 0,
-            "output_dim": 1,
-        })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        scales = GroupQuantScaleParameter(data=torch.empty(
+            input_size_per_partition // self.quant_config.group_size,
+            output_size_per_partition,
+            dtype=params_dtype,
+        ),
+                                          input_dim=0,
+                                          output_dim=1,
+                                          weight_loader=weight_loader)
 
         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data,
+                                           requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
+                                          requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data,
+                                          requires_grad=False)
 
     def apply(self,
               layer: torch.nn.Module,

+ 33 - 41
aphrodite/quantization/awq_marlin.py

@@ -2,12 +2,12 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from loguru import logger
-from torch.nn.parameter import Parameter
 
 from aphrodite import _custom_ops as ops
-from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
-                                              set_weight_attrs)
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
+from aphrodite.modeling.parameter import (GroupQuantScaleParameter,
+                                          PackedAphroditeParameter)
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
@@ -147,6 +147,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
     ) -> None:
         del output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
@@ -160,59 +161,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             input_size=input_size,
             group_size=group_size)
 
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedAphroditeParameter(
+            data=torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
 
         num_groups = input_size_per_partition // group_size
 
-        qzeros = Parameter(
-            torch.empty(
+        qzeros = PackedAphroditeParameter(
+            data=torch.empty(
                 num_groups,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-
-        scales = Parameter(
-            torch.empty(
-                num_groups,
-                output_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": 0,
-            "output_dim": 1,
-        })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        scales = GroupQuantScaleParameter(data=torch.empty(
+            num_groups,
+            output_size_per_partition,
+            dtype=params_dtype,
+        ),
+                                          input_dim=0,
+                                          output_dim=1,
+                                          weight_loader=weight_loader)
 
         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
 
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
@@ -224,6 +210,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
     # Here, we handle the repacking
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = layer.qweight.device
+        layer.qweight = torch.nn.Parameter(layer.qweight.data,
+                                           requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
+                                          requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data,
+                                          requires_grad=False)
 
         # Allocate marlin workspace
         layer.workspace = marlin_make_workspace(