
feat: add sparse marlin for compressed tensors

AlpinDale 7 months ago
parent
commit
e2dbe5f05c

+ 32 - 16
aphrodite/quantization/compressed_tensors/compressed_tensors.py

@@ -7,16 +7,20 @@ from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.quantization.base_config import QuantizationConfig  # noqa: E501
 from aphrodite.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme, CompressedTensorsW4A16,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
 from aphrodite.quantization.compressed_tensors.utils import (
-    QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
+    CompressionFormat, QuantizationArgs, QuantizationStrategy,
+    find_first_name_or_class_match)
 
 
 class CompressedTensorsConfig(QuantizationConfig):
 
-    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]):
+    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
+                 quant_format: str):
         self.ignore = ignore
         self.layer_quant_details = layer_quant_details
+        self.quant_format = quant_format
 
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -45,6 +49,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         layer_quant_details: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)
+        quant_format: str = config.get("format", None)
 
         # The quant_config has multiple config_groups, each containing
         # an input_activations key with details about how the activations are
@@ -68,7 +73,9 @@ class CompressedTensorsConfig(QuantizationConfig):
                 except Exception:
                     layer_quant_details[target]["input_activations"] = None
 
-        return cls(layer_quant_details=layer_quant_details, ignore=ignore)
+        return cls(layer_quant_details=layer_quant_details,
+                   ignore=ignore,
+                   quant_format=quant_format)
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
@@ -109,17 +116,26 @@ class CompressedTensorsConfig(QuantizationConfig):
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
 
         if self._is_w4a16(weight_quant, input_quant):
-            return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
-                                          strategy=weight_quant.strategy,
-                                          group_size=weight_quant.group_size)
-
-        if self._is_static_tensor_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8StaticTensor()
-
-        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8DynamicToken()
-
-        raise NotImplementedError("Scheme not supported.")
+            if self.quant_format == CompressionFormat.marlin_24.value:
+                return CompressedTensorsW4A16Sparse24(
+                    strategy=weight_quant.strategy,
+                    num_bits=weight_quant.num_bits,
+                    group_size=weight_quant.group_size)
+            if self.quant_format == CompressionFormat.pack_quantized.value:
+                return CompressedTensorsW4A16(
+                    num_bits=weight_quant.num_bits,
+                    strategy=weight_quant.strategy,
+                    group_size=weight_quant.group_size)
+
+        if self.quant_format == CompressionFormat.int_quantized.value:
+            if self._is_static_tensor_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8StaticTensor()
+
+            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8DynamicToken()
+
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")
 
     def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
 
@@ -164,9 +180,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         scheme = self.quantization_config.get_scheme(layer=layer)
         scheme.create_weights(
             layer=layer,
+            input_size=input_size,
             input_size_per_partition=input_size_per_partition,
             output_partition_sizes=output_partition_sizes,
-            input_size=input_size,
             output_size=output_size,
             params_dtype=params_dtype,
             weight_loader=weight_loader)
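
For orientation (a sketch, not taken from the commit): the new quant_format is read from the checkpoint's quantization config, whose "format" field now decides whether the W4A16 path uses the pack-quantized kernel or the new 2:4 sparse Marlin kernel. A minimal example of such a config follows; the "marlin-24" value matches the CompressionFormat enum added in utils.py below, while the config_groups payload, the target names, and the "lm_head" ignore entry are illustrative assumptions.

    from aphrodite.quantization.compressed_tensors.compressed_tensors import (
        CompressedTensorsConfig)

    # Hypothetical quantization config as it might appear in a
    # compressed-tensors checkpoint; the group contents are illustrative.
    example_config = {
        "format": "marlin-24",          # CompressionFormat.marlin_24.value
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 4,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "group",
                    "group_size": 128,
                },
                "input_activations": None,
            },
        },
    }

    quant_config = CompressedTensorsConfig.from_config(example_config)
    assert quant_config.quant_format == "marlin-24"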

+ 2 - 0
aphrodite/quantization/compressed_tensors/schemes/__init__.py

@@ -2,6 +2,8 @@ from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import \
     CompressedTensorsUnquantized  # noqa: F401
 from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
+from .compressed_tensors_w4a16_24 import \
+    CompressedTensorsW4A16Sparse24  # noqa: F401
 from .compressed_tensors_w8a8_dynamictoken import \
     CompressedTensorsW8A8DynamicToken  # noqa: F401, E501
 from .compressed_tensors_w8a8_statictensor import \

+ 134 - 0
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py

@@ -0,0 +1,134 @@
+from typing import Callable, List, Optional
+
+import torch
+from torch.nn import Parameter
+
+from aphrodite import _custom_ops as ops
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.compressed_tensors.schemes import \
+    CompressedTensorsScheme
+from aphrodite.quantization.gptq_marlin_24 import (GPTQ_MARLIN_24_MAX_PARALLEL,
+                                                   GPTQ_MARLIN_24_MIN_THREAD_N)
+
+__all__ = ["CompressedTensorsW4A16Sparse24"]
+
+
+class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None):
+        self.strategy = strategy
+        self.group_size = group_size
+        self.num_bits = num_bits
+        self.tile_size = 16
+
+        if self.strategy == "group" and self.group_size is None:
+            raise ValueError(
+                "group_size must be given when using strategy group")
+
+    def create_weights(self, layer: torch.nn.Module, input_size: int,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        pack_factor = 32 // self.num_bits
+        output_size_per_partition = sum(output_partition_sizes)
+
+        qweight = Parameter(
+            torch.empty(
+                input_size_per_partition // self.tile_size // 2,
+                output_size_per_partition * self.tile_size // pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight,
+            {
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": pack_factor,
+                "marlin_tile_size": self.tile_size,
+                "weight_loader": weight_loader
+            },
+        )
+
+        layer.register_parameter("weight_packed", qweight)
+
+        input_groups = (1 if self.group_size is None else
+                        input_size_per_partition // self.group_size)
+
+        scales = Parameter(
+            torch.empty(
+                input_groups,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            scales,
+            {
+                "output_dim": 1,
+                "input_dim": None if input_groups == 1 else 0,
+                "weight_loader": weight_loader
+            },
+        )
+        layer.register_parameter("scale_packed", scales)
+
+        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_shape", weight_shape)
+        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+
+        meta = Parameter(
+            torch.empty(
+                input_size_per_partition // 8 // 2 // 2,
+                output_size_per_partition * 2,
+                dtype=torch.int16,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            meta,
+            {
+                "input_dim": 0,
+                "packed_dim": 1,
+                "pack_factor": 1,
+                "output_dim": 1,
+                "marlin_tile_size": 2,
+                "weight_loader": weight_loader
+            },
+        )
+        layer.register_parameter("meta", meta)
+
+        max_workspace_size = (
+            output_size_per_partition //
+            GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
+        workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
+                              requires_grad=False)
+        layer.workspace = workspace
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        qweight = layer.weight_packed
+        meta = layer.meta
+        scales = layer.scale_packed
+        workspace = layer.workspace
+
+        x_2d = x.view(-1, x.shape[-1])
+
+        size_m = x_2d.shape[0]
+        size_k = x_2d.shape[1]
+        size_n = scales.shape[1]
+
+        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
+                                            workspace, self.num_bits, size_m,
+                                            size_n, size_k)
+
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+        return output
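
As a sanity check on the packing logic in create_weights() above (a worked example, not part of the commit): for a hypothetical partition with input and output size 4096, num_bits = 4 and group_size = 128, the registered parameter shapes work out as follows.

    # Hypothetical per-partition sizes for a 4-bit, group-128 layer.
    num_bits, tile_size, group_size = 4, 16, 128
    k, n = 4096, 4096                      # input / output size per partition

    pack_factor = 32 // num_bits           # 8 int4 values packed per int32
    qweight_shape = (k // tile_size // 2, n * tile_size // pack_factor)
    meta_shape = (k // 8 // 2 // 2, n * 2)
    scales_shape = (k // group_size, n)

    assert qweight_shape == (128, 8192)    # the extra // 2 reflects 2:4 sparsity
    assert meta_shape == (128, 8192)       # int16 metadata encoding the 2:4 pattern
    assert scales_shape == (32, 4096)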

+ 8 - 0
aphrodite/quantization/compressed_tensors/utils.py

@@ -6,6 +6,14 @@ from pydantic import BaseModel, Field
 from torch.nn import Module
 
 
+class CompressionFormat(Enum):
+    dense = "dense"
+    sparse_bitmask = "sparse-bitmask"
+    int_quantized = "int-quantized"
+    pack_quantized = "pack-quantized"
+    marlin_24 = "marlin-24"
+
+
 class QuantizationType(str, Enum):
     """
     Enum storing quantization type options