Browse Source

feat: w4a16 support for compressed-tensors

AlpinDale 7 months ago
parent
commit
1d00b61622

+ 1 - 1
aphrodite/distributed/parallel_state.py

@@ -138,7 +138,7 @@ class GroupCoordinator:
 
         # lazy import to avoid documentation build error
         from aphrodite.distributed.device_communicators.custom_all_reduce import \
-            CustomAllreduce
+            CustomAllreduce  # noqa: E501
         from aphrodite.distributed.device_communicators.pynccl import \
             PyNcclCommunicator
 

+ 36 - 8
aphrodite/quantization/compressed_tensors/compressed_tensors.py

@@ -6,8 +6,8 @@ from pydantic import BaseModel
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.quantization.base_config import QuantizationConfig  # noqa: E501
 from aphrodite.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsScheme, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
 from aphrodite.quantization.compressed_tensors.utils import (
     QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
 
@@ -46,16 +46,27 @@ class CompressedTensorsConfig(QuantizationConfig):
         layer_quant_details: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)
 
+        # The quant_config has multiple config_groups, each containing
+        # an input_activations key with details about how the activations are
+        # quantized, a weights key indicating how the weights are quantized,
+        # and a list of targets under the `targets` key, dictating which
+        # layers are impacted by the quantization details. The quantization
+        # details follow the structure defined by the QuantizationArgs
+        # pydantic model, which is used to verify the structure of the
+        # quant_config and also store the details for later use.
         for key, quant_config in config["config_groups"].items():
             targets = quant_config.get("targets")
             for target in targets:
                 layer_quant_details[target] = {}
                 layer_quant_details[target][
-                    "weight"] = QuantizationArgs.parse_obj(
+                    "weights"] = QuantizationArgs.parse_obj(
                         quant_config.get("weights"))
-                layer_quant_details[target][
-                    "input"] = QuantizationArgs.parse_obj(
-                        quant_config.get("input_activations"))
+                try:
+                    layer_quant_details[target][
+                        "input_activations"] = QuantizationArgs.parse_obj(
+                            quant_config.get("input_activations"))
+                except Exception:
+                    layer_quant_details[target]["input_activations"] = None
 
         return cls(layer_quant_details=layer_quant_details, ignore=ignore)
 
@@ -85,8 +96,23 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
 
+    def _is_w4a16(self, weight_quant: BaseModel,
+                  input_quant: BaseModel) -> bool:
+        input_quant_none = input_quant is None
+        is_4_bits = weight_quant.num_bits == 4
+        is_symmetric = weight_quant.symmetric
+        is_static = not weight_quant.dynamic
+
+        return is_4_bits and input_quant_none and is_symmetric and is_static
+
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
+
+        if self._is_w4a16(weight_quant, input_quant):
+            return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
+                                          strategy=weight_quant.strategy,
+                                          group_size=weight_quant.group_size)
+
         if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8StaticTensor()
 
@@ -112,8 +138,9 @@ class CompressedTensorsConfig(QuantizationConfig):
             raise ValueError(
                 f"Could not find quantization details for {layer}.")
 
-        return self._get_schema(weight_quant=layer_quant_details["weight"],
-                                input_quant=layer_quant_details["input"])
+        return self._get_schema(
+            weight_quant=layer_quant_details["weights"],
+            input_quant=layer_quant_details["input_activations"])
 
 
 class CompressedTensorsLinearMethod(LinearMethodBase):
@@ -139,6 +166,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
             layer=layer,
             input_size_per_partition=input_size_per_partition,
             output_partition_sizes=output_partition_sizes,
+            input_size=input_size,
             output_size=output_size,
             params_dtype=params_dtype,
             weight_loader=weight_loader)

+ 7 - 6
aphrodite/quantization/compressed_tensors/schemes/__init__.py

@@ -1,7 +1,8 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
-from .compressed_tensors_unquantized import (  # noqa: F401
-    CompressedTensorsUnquantized)
-from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
-    CompressedTensorsW8A8DynamicToken)
-from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
-    CompressedTensorsW8A8StaticTensor)
+from .compressed_tensors_unquantized import \
+    CompressedTensorsUnquantized  # noqa: F401
+from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
+from .compressed_tensors_w8a8_dynamictoken import \
+    CompressedTensorsW8A8DynamicToken  # noqa: F401, E501
+from .compressed_tensors_w8a8_statictensor import \
+    CompressedTensorsW8A8StaticTensor  # noqa: F401, E501

+ 169 - 0
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py

@@ -0,0 +1,169 @@
+from typing import Callable, List, Optional
+
+import torch
+from torch.nn import Parameter
+
+from aphrodite import _custom_ops as ops
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.compressed_tensors.schemes import \
+    CompressedTensorsScheme
+from aphrodite.quantization.gptq_marlin import (GPTQ_MARLIN_MAX_PARALLEL,
+                                                GPTQ_MARLIN_MIN_THREAD_N,
+                                                GPTQMarlinState,
+                                                marlin_permute_scales)
+
+__all__ = ["CompressedTensorsW4A16"]
+
+
+class CompressedTensorsW4A16(CompressedTensorsScheme):
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None):
+        self.num_bits = num_bits
+        self.strategy = strategy
+        self.group_size = group_size
+
+        if self.strategy == "group" and self.group_size is None:
+            raise ValueError(
+                "group_size must be given when using strategy group")
+
+    def create_weights(self, layer: torch.nn.Module, input_size: int,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        pack_factor = 32 // self.num_bits
+        output_size_per_partition = sum(output_partition_sizes)
+
+        if self.group_size is not None:
+            group_size = self.group_size
+        else:
+            group_size = input_size
+
+        weight_scale_dim = None
+        scales_and_zp_size = input_size // group_size
+
+        if (input_size != input_size_per_partition
+                and self.group_size is not None):
+            weight_scale_dim = 1
+            scales_and_zp_size = input_size_per_partition // group_size
+
+        weight = Parameter(
+            torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "packed_dim": 1,
+                "pack_factor": pack_factor
+            })
+        set_weight_attrs(weight, {"weight_loader": weight_loader})
+
+        layer.register_parameter("weight_packed", weight)
+
+        weight_scale = Parameter(
+            torch.empty(
+                output_size_per_partition,
+                scales_and_zp_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        set_weight_attrs(weight_scale, {
+            "input_dim": weight_scale_dim,
+            "output_dim": 0
+        })
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # A 2D array defining the original shape of the weights
+        # before packing
+        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_shape", weight_shape)
+        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        layer.input_size = input_size
+        layer.marlin_state = GPTQMarlinState.REPACK
+        layer.is_k_full = True
+        layer.group_size = group_size
+
+        max_workspace_size = (
+            output_size_per_partition //
+            GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
+
+        workspace = torch.zeros(max_workspace_size,
+                                dtype=torch.int,
+                                requires_grad=False)
+        layer.workspace = workspace
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        size_m = reshaped_x.shape[0]
+        part_size_n = layer.output_size_per_partition
+        part_size_k = layer.input_size_per_partition
+
+        out_shape = x.shape[:-1] + (part_size_n, )
+
+        if layer.marlin_state == GPTQMarlinState.REPACK:
+            layer.marlin_state = GPTQMarlinState.READY
+
+            # Newly generated tensors need to replace existing tensors that are
+            # already registered as parameters by vLLM (and won't be freed)
+            def replace_tensor(name, new_t):
+                # It is important to use resize_() here since it ensures
+                # the same buffer is reused
+                getattr(layer, name).resize_(new_t.shape)
+                getattr(layer, name).copy_(new_t)
+                del new_t
+
+            cur_device = layer.weight_packed.device
+
+            # Reset g_idx related tensors
+            layer.g_idx = Parameter(torch.empty(0,
+                                                dtype=torch.int,
+                                                device=cur_device),
+                                    requires_grad=False)
+            layer.g_idx_sort_indices = Parameter(torch.empty(
+                0, dtype=torch.int, device=cur_device),
+                                                 requires_grad=False)
+
+            # Repack weights
+            marlin_qweight = ops.gptq_marlin_repack(
+                layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices,
+                part_size_k, part_size_n, self.num_bits)
+
+            replace_tensor("weight_packed", marlin_qweight)
+
+            # Permute scales
+            scales_size_k = part_size_k
+            scales_size_n = part_size_n
+
+            marlin_scales = marlin_permute_scales(
+                layer.weight_scale.squeeze().t().contiguous(), scales_size_k,
+                scales_size_n, layer.group_size, self.num_bits)
+            replace_tensor("weight_scale", marlin_scales)
+
+        output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed,
+                                      layer.weight_scale, layer.g_idx,
+                                      layer.g_idx_sort_indices,
+                                      layer.workspace, self.num_bits, size_m,
+                                      part_size_n, part_size_k,
+                                      layer.is_k_full)
+        return output.reshape(out_shape)