
quantization: update marlin to use `AphroditeParameters` (#913)

AlpinDale · 2 months ago · parent · commit 799667737b

2 changed files with 102 additions and 71 deletions:
  1. aphrodite/modeling/layers/linear.py (+1 -1)
  2. aphrodite/quantization/marlin.py (+101 -70)

aphrodite/modeling/layers/linear.py (+1 -1)

@@ -26,7 +26,7 @@ from aphrodite.quantization.base_config import (QuantizationConfig,
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
     "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
-    "Fp8LinearMethod",
+    "Fp8LinearMethod", "MarlinLinearMethod"
 ]
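
Adding "MarlinLinearMethod" to this registry is what routes Marlin layers onto the v2 weight-loading path, which expects the AphroditeParameter subclasses created in marlin.py below rather than plain torch.nn.Parameter objects annotated via set_weight_attrs. As a rough sketch of how the registry is typically consumed when a linear layer wires up its weight loader (paraphrased for illustration; the selection code is not part of this diff, and the layer attributes shown are assumptions):

    WEIGHT_LOADER_V2_SUPPORTED = [
        "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
        "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
        "Fp8LinearMethod", "MarlinLinearMethod",
    ]

    def select_weight_loader(layer, quant_method):
        # v2 loader: understands AphroditeParameter metadata (packed dims,
        # scale grouping, per-parameter weight_loader callbacks).
        if type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            return layer.weight_loader_v2
        # Legacy loader: relies on set_weight_attrs-style attributes.
        return layer.weight_loader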
 
 

aphrodite/quantization/marlin.py (+101 -70)

@@ -7,7 +7,10 @@ from torch.nn.parameter import Parameter
 from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
-from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.modeling.parameter import (BaseAphroditeParameter,
+                                          ChannelQuantScaleParameter,
+                                          GroupQuantScaleParameter,
+                                          PackedAphroditeParameter)
 from aphrodite.quantization.base_config import QuantizationConfig
 
 
@@ -29,7 +32,8 @@ class MarlinConfig(QuantizationConfig):
             raise ValueError(
                 "Currently, only group size 128 and -1 (channelwise) "
                 "is supported for Marlin, but got group_size of "
-                f"{self.group_size}")
+                f"{self.group_size}"
+            )
 
         # 4 Bits packed into 32 bit datatype.
         self.pack_factor = 32 // 4
@@ -51,8 +55,10 @@ class MarlinConfig(QuantizationConfig):
         self.perm_len = 1024
 
     def __repr__(self) -> str:
-        return (f"MarlinConfig(group_size={self.group_size}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+        return (
+            f"MarlinConfig(group_size={self.group_size}, "
+            f"lm_head_quantized={self.lm_head_quantized})"
+        )
 
     @classmethod
     def get_name(cls) -> str:
@@ -74,33 +80,42 @@ class MarlinConfig(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
         group_size = cls.get_from_keys(config, ["group_size"])
-        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
-                                                 default=False)
+        lm_head_quantized = cls.get_from_keys_or(
+            config, ["lm_head"], default=False
+        )
         return cls(group_size, lm_head_quantized)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+        cls, hf_quant_cfg, user_quant
+    ) -> Optional[str]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_marlin_format: bool
-        is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
-                            or hf_quant_cfg.get("is_marlin_format", False))
+        is_marlin_format = hf_quant_cfg.get(
+            "checkpoint_format"
+        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
 
-        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
-                               or user_quant == "marlin")
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
+        )
 
         if is_marlin_format and is_valid_user_quant:
-            msg = ("The model is serialized in {} format. Using {} kernel.".
-                   format(cls.get_name(), cls.get_name()))
+            msg = (
+                "The model is serialized in {} format. Using {} kernel.".format(
+                    cls.get_name(), cls.get_name()
+                )
+            )
             logger.info(msg)
             return cls.get_name()
 
         return None
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["MarlinLinearMethod"]:
-        if (isinstance(layer, LinearBase) or
-            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["MarlinLinearMethod"]:
+        if isinstance(layer, LinearBase) or (
+            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+        ):
             return MarlinLinearMethod(self)
         return None
 
@@ -129,10 +144,12 @@ class MarlinLinearMethod(LinearMethodBase):
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]
 
         if params_dtype != torch.float16:
             raise ValueError(
-                f"The params dtype must be float16, but got {params_dtype}")
+                f"The params dtype must be float16, but got {params_dtype}"
+            )
 
         # Validate output_size_per_partition
         output_size_per_partition = sum(output_partition_sizes)
@@ -140,91 +157,104 @@ class MarlinLinearMethod(LinearMethodBase):
             raise ValueError(
                 f"Weight output_size_per_partition = "
                 f"{output_size_per_partition} is not divisible by "
-                f"min_n_threads = {self.quant_config.min_n_threads}.")
+                f"min_n_threads = {self.quant_config.min_n_threads}."
+            )
         if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 f"Weight output_size_per_partition = "
                 f"{output_size_per_partition} is not divisible by "
-                f"pack_factor = {self.quant_config.pack_factor}.")
+                f"pack_factor = {self.quant_config.pack_factor}."
+            )
 
         # Validate input_size_per_partition
         if input_size_per_partition % self.quant_config.min_k_threads != 0:
             raise ValueError(
                 f"Weight input_size_per_partition = "
                 f"{input_size_per_partition} is not divisible by "
-                f"min_k_threads = {self.quant_config.min_k_threads}.")
-        if (self.quant_config.group_size != -1 and
-                input_size_per_partition % self.quant_config.group_size != 0):
-            raise ValueError(f"Weight input_size_per_partition = "
-                             f"{input_size_per_partition} is not divisible by "
-                             f"group_size = {self.quant_config.group_size}.")
+                f"min_k_threads = {self.quant_config.min_k_threads}."
+            )
+        if (
+            self.quant_config.group_size != -1
+            and input_size_per_partition % self.quant_config.group_size != 0
+        ):
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"group_size = {self.quant_config.group_size}."
+            )
 
         # Check that we have at least 4 tiles horizontally in the shard
         num_tiles_per_perm = self.quant_config.perm_len // (
-            self.quant_config.tile_size**2)
+            self.quant_config.tile_size**2
+        )
         if output_size_per_partition % num_tiles_per_perm != 0:
             raise ValueError(
-                "Each permutation group must reside on the same gpu")
+                "Each permutation group must reside on the same gpu"
+            )
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedAphroditeParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size,
-                output_size_per_partition * self.quant_config.tile_size //
-                self.quant_config.pack_factor,
+                output_size_per_partition
+                * self.quant_config.tile_size
+                // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader,
         )
 
         # Determine if channelwise or not
-        input_groups = (1 if self.quant_config.group_size == -1 else
-                        input_size_per_partition //
-                        self.quant_config.group_size)
+        input_groups = (
+            1
+            if self.quant_config.group_size == -1
+            else input_size_per_partition // self.quant_config.group_size
+        )
 
-        scales = Parameter(
-            torch.empty(
+        weight_scale_args = {
+            "data": torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader": weight_loader,
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(
+                output_dim=1, **weight_scale_args
+            )
+        else:
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
-            output_size_per_partition //
-            self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+            output_size_per_partition // self.quant_config.min_n_threads
+        ) * self.quant_config.max_parallel
+        workspace = BaseAphroditeParameter(
+            data=torch.zeros(
+                max_workspace_size, device="cuda", dtype=torch.int
+            ),
+            weight_loader=weight_loader,
+        )
 
         layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
     def apply(
         self,
@@ -242,10 +272,11 @@ class MarlinLinearMethod(LinearMethodBase):
         size_k = x_2d.shape[1]
         size_n = scales.shape[1]
 
-        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
-                                    size_n, size_k)
+        output_2d = ops.marlin_gemm(
+            x_2d, qweight, scales, workspace, size_m, size_n, size_k
+        )
 
-        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
 
         if bias is not None:
             output.add_(bias)  # In-place add
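
The substantive change in create_weights is that shard bookkeeping now lives on the parameters themselves: the packed qweight advertises packed_dim, packed_factor, and marlin_tile_size, the scales pick a channel-wise or group-wise parameter class depending on group_size, and each parameter carries its own weight_loader. The v2 loader uses the packing metadata to map shard sizes and offsets from logical int4 columns to the packed, Marlin-tiled int32 layout; after loading, process_weights_after_loading re-wraps the tensors as plain torch.nn.Parameter so torch.compile sees ordinary parameters. A self-contained sketch of that index adjustment (an illustration of the idea, not the aphrodite.modeling.parameter implementation):

    def adjust_packed_shard_indexes(shard_size: int, shard_offset: int,
                                    packed_factor: int = 8,
                                    marlin_tile_size: int = 16):
        """Translate a shard's size/offset along the packed output dim from
        unpacked int4 columns into the packed int32, Marlin-tiled columns
        that the registered qweight tensor actually stores."""
        # Eight 4-bit values share one int32 word...
        shard_size //= packed_factor
        shard_offset //= packed_factor
        # ...and Marlin lays the packed words out in tile_size x tile_size
        # tiles, which stretches the packed output dimension by tile_size.
        return shard_size * marlin_tile_size, shard_offset * marlin_tile_size

    # Example: a 4096-column logical shard starting at column 8192 lands at
    # adjust_packed_shard_indexes(4096, 8192) == (8192, 16384) in qweight.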