feat: per-channel support for static activation quant

AlpinDale, 7 months ago
Parent
Current commit
b753ff7870

+ 7 - 3
aphrodite/quantization/compressed_tensors/compressed_tensors.py

@@ -84,8 +84,11 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
-        is_tensor = (weight_quant.strategy == input_quant.strategy ==
-                     QuantizationStrategy.TENSOR.value)
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_tensor = (weight_strategy and input_quant.strategy
+                     == QuantizationStrategy.TENSOR.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic

@@ -130,7 +133,8 @@ class CompressedTensorsConfig(QuantizationConfig):

         if self.quant_format == CompressionFormat.int_quantized.value:
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8StaticTensor()
+                return CompressedTensorsW8A8StaticTensor(
+                    strategy=weight_quant.strategy)

             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8DynamicToken(

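For context, a minimal self-contained sketch (not code from this commit; `QuantArgs` and `is_static_w8a8` are hypothetical stand-ins for the BaseModel configs and the method above) of what the relaxed check now accepts: per-CHANNEL weight scales are allowed as long as the activation side stays static, symmetric, per-TENSOR int8.

```python
from dataclasses import dataclass
from enum import Enum


class QuantizationStrategy(str, Enum):  # mirrors the names used in the diff
    TENSOR = "tensor"
    CHANNEL = "channel"


@dataclass
class QuantArgs:  # stand-in for the weight_quant / input_quant configs
    num_bits: int
    strategy: str
    symmetric: bool
    dynamic: bool


def is_static_w8a8(weight_quant: QuantArgs, input_quant: QuantArgs) -> bool:
    is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
    # weights may be per-tensor OR per-channel ...
    weight_strategy = weight_quant.strategy in (
        QuantizationStrategy.TENSOR.value,
        QuantizationStrategy.CHANNEL.value,
    )
    # ... but the static activation scale must stay per-tensor
    is_tensor_act = input_quant.strategy == QuantizationStrategy.TENSOR.value
    is_symmetric = weight_quant.symmetric and input_quant.symmetric
    is_static = not weight_quant.dynamic and not input_quant.dynamic
    return (is_8_bits and weight_strategy and is_tensor_act
            and is_symmetric and is_static)


w = QuantArgs(8, "channel", True, False)  # per-channel weight scales
a = QuantArgs(8, "tensor", True, False)   # static per-tensor activation scale
assert is_static_w8a8(w, a)               # now routes to the static W8A8 scheme
```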
+ 83 - 0
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py

@@ -0,0 +1,83 @@
+from typing import Callable, List, Tuple, Union
+
+import torch
+from torch.nn import Parameter
+
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.compressed_tensors.schemes import \
+    CompressedTensorsScheme
+from aphrodite.quantization.compressed_tensors.utils import \
+    QuantizationStrategy
+
+
+class CompressedTensorsW8A8(CompressedTensorsScheme):
+
+    def __init__(self, strategy: str):
+        self.strategy = strategy
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        assert isinstance(shard_id, str)
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
+    def scales_shard_splitter(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int],
+            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shard_id = self._shard_id_as_int(shard_id)
+        offset = sum(logical_widths[:shard_id])
+        size = logical_widths[shard_id]
+        # update loaded weight with copies for broadcast.
+        loaded_weight = loaded_weight.repeat(size)
+        return param[offset:offset + size], loaded_weight
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        is_tensor_partitioned = len(output_partition_sizes) != 1
+        weight_scale_dim = sum(output_partition_sizes) if (
+            is_tensor_partitioned
+            or self.strategy == QuantizationStrategy.CHANNEL) else 1
+
+        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            shape = (weight_scale_dim, 1)
+
+        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_scale", weight_scale)
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=torch.int8),
+                           requires_grad=False)
+
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "weight_loader": weight_loader,
+                "logical_widths": output_partition_sizes
+            })
+
+        # Don't need a shard_splitter for channel-wise quantization
+        # Use the default loading method
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "output_dim": 0,
+            })
+        else:
+            set_weight_attrs(
+                weight_scale, {
+                    "logical_widths": output_partition_sizes,
+                    "shard_splitter": self.scales_shard_splitter,
+                })

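A small illustration of the weight_scale shape logic in the new shared `CompressedTensorsW8A8` scheme (a sketch assuming `strategy` compares equal to the string values "tensor"/"channel"; `weight_scale_shape` is a hypothetical helper, not part of the commit):

```python
def weight_scale_shape(output_partition_sizes, strategy):
    # Mirrors CompressedTensorsW8A8.create_weights: one scale per output row
    # for CHANNEL, otherwise one scale slot per logical shard (or a single
    # scalar when the layer is not partitioned).
    is_partitioned = len(output_partition_sizes) != 1
    dim = sum(output_partition_sizes) if (
        is_partitioned or strategy == "channel") else 1
    return (dim, 1) if strategy == "channel" else (dim,)


# e.g. a fused QKV projection with logical widths [4096, 1024, 1024]
print(weight_scale_shape([4096, 1024, 1024], "tensor"))   # (6144,)  -> per-shard scales, broadcast at load time
print(weight_scale_shape([4096, 1024, 1024], "channel"))  # (6144, 1) -> one scale per output row
print(weight_scale_shape([4096], "tensor"))               # (1,)      -> single scalar scale
```

With one scale per output row, the CHANNEL case can be loaded shard-by-shard through the default `output_dim`-based loader, which is why no `shard_splitter` is registered for it above.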
+ 10 - 79
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py

@@ -1,42 +1,15 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List

 import torch
-from torch.nn import Parameter

 from aphrodite import _custom_ops as custom_ops
-from aphrodite.modeling.utils import set_weight_attrs
-from aphrodite.quantization.compressed_tensors.schemes import \
-    CompressedTensorsScheme
-from aphrodite.quantization.compressed_tensors.utils import \
-    QuantizationStrategy
+from aphrodite.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import \
+    CompressedTensorsW8A8  # noqa: E501

 __all__ = ["CompressedTensorsW8A8DynamicToken"]


-class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
-
-    def __init__(self, strategy: str):
-        self.strategy = strategy
-
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
+class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -44,54 +17,12 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):

-        # When the scales have a single value, it is required that they be
-        # on the CPU for performance and CUDA Graphs compatibility. Please
-        # refer to the comment in
-        # CompressedTensorsW8A8StaticTensor::create_weights for further
-        # information.
-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        # when doing channel-wise quantization, number of scales
-        # is equal to output_dim
-        weight_scale_dim = sum(output_partition_sizes) if (
-            is_tensor_partitioned
-            or self.strategy == QuantizationStrategy.CHANNEL) else 1
-
-        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            shape = (weight_scale_dim, 1)
-
-        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
-                                 requires_grad=False)
-
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(
-            weight, {
-                "input_dim": 1,
-                "output_dim": 0,
-                "weight_loader": weight_loader,
-                "logical_widths": output_partition_sizes
-            })
-
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
-
-        # Don't need a shard_splitter for channel-wise quantization
-        # Use the default loading method
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            set_weight_attrs(weight_scale, {
-                "output_dim": 0,
-            })
-        else:
-            set_weight_attrs(
-                weight_scale, {
-                    "logical_widths": output_partition_sizes,
-                    "shard_splitter": self.scales_shard_splitter,
-                })
+        super().create_weights(
+            layer=layer,
+            output_partition_sizes=output_partition_sizes,
+            input_size_per_partition=input_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader)

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight

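For the per-tensor path that keeps the shard splitter, here is a toy walk-through of what `scales_shard_splitter` (now inherited from `CompressedTensorsW8A8`) does when loading a fused QKV layer; `split_scales` is a hypothetical standalone rendition for illustration only:

```python
import torch


def split_scales(param, loaded_scale, shard_id, logical_widths):
    # Each shard's single per-tensor scale is repeated across that shard's
    # output rows and written into the matching slice of the fused
    # weight_scale parameter.
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    idx = qkv_idxs[shard_id] if isinstance(shard_id, str) else shard_id
    offset = sum(logical_widths[:idx])
    size = logical_widths[idx]
    return param[offset:offset + size], loaded_scale.repeat(size)


logical_widths = [4, 2, 2]                      # toy q/k/v output widths
fused_scale = torch.zeros(sum(logical_widths))  # the registered weight_scale
k_scale = torch.tensor([0.05])                  # checkpoint holds one scale for "k"
dst, src = split_scales(fused_scale, k_scale, "k", logical_widths)
dst.copy_(src)                                  # rows 4..5 now hold 0.05
print(fused_scale)  # tensor([0., 0., 0., 0., 0.05, 0.05, 0., 0.])
```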
+ 11 - 51
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py

@@ -1,37 +1,17 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List

 import torch
 from torch.nn import Parameter

 from aphrodite import _custom_ops as custom_ops
 from aphrodite.modeling.utils import set_weight_attrs
-from aphrodite.quantization.compressed_tensors.schemes import \
-    CompressedTensorsScheme
+from aphrodite.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import \
+    CompressedTensorsW8A8  # noqa: E501

 __all__ = ["CompressedTensorsW8A8StaticTensor"]


-class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
-
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
+class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -39,41 +19,21 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):

-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(
-            output_partition_sizes) if is_tensor_partitioned else 1
+        super().create_weights(
+            layer=layer,
+            output_partition_sizes=output_partition_sizes,
+            input_size_per_partition=input_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader)

         input_scale = Parameter(torch.empty(1, dtype=torch.float32),
                                 requires_grad=False)

-        weight_scale = Parameter(torch.empty(weight_scale_dim,
-                                             dtype=torch.float32),
-                                 requires_grad=False)
-
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            "weight_loader": weight_loader,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
         layer.register_parameter("input_scale", input_scale)
         set_weight_attrs(input_scale, {
             "weight_loader": weight_loader,
             "ignore_warning": True,
         })
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(
-            weight_scale, {
-                "weight_loader": weight_loader,
-                "shard_splitter": self.scales_shard_splitter,
-                "logical_widths": output_partition_sizes,
-                "ignore_warning": True,
-            })

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
@@ -84,4 +44,4 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
         x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)

         return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
-                                            weight_scale, x.dtype)
+                                            weight_scale, x.dtype)
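Finally, a rough pure-PyTorch emulation of the `apply_weights` path above (an assumption about the math for illustration, not the CUTLASS kernel itself), showing how a per-channel `weight_scale` of shape (num_output_channels, 1) rescales each output column alongside the static activation scale:

```python
import torch


def reference_static_w8a8(x, weight_q, weight_scale, act_scale,
                          out_dtype=torch.float16):
    # Assumed semantics of scaled_int8_quant + cutlass_scaled_mm:
    #   y = (round(x / act_scale) @ W_q^T) * act_scale * weight_scale
    x_q = torch.clamp(torch.round(x / act_scale), -128, 127)  # static activation quant
    acc = x_q @ weight_q.to(torch.float32).t()                # int8 GEMM, fp32 accumulate
    # weight_scale is a per-shard vector for per-tensor weights, or shape
    # (num_output_channels, 1) for the new per-channel strategy; either way
    # it rescales each output column here.
    y = acc * act_scale * weight_scale.reshape(1, -1)
    return y.to(out_dtype)


x = torch.randn(2, 8)
w_q = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
per_channel_scale = torch.rand(4, 1) * 0.01   # one scale per output row
act_scale = torch.tensor(0.02)
print(reference_static_w8a8(x, w_q, per_channel_scale, act_scale).shape)  # torch.Size([2, 4])
```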