chore: use the `compressed-tensors` library to avoid code duplication (#704)

AlpinDale 6 months ago
parent
commit
f5bbf07c90

+ 5 - 2
aphrodite/quantization/compressed_tensors/compressed_tensors.py

@@ -1,6 +1,10 @@
 from typing import Any, Dict, List, Optional
 
 import torch
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy,
+                                             QuantizationType)
 from pydantic import BaseModel
 
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
@@ -14,8 +18,7 @@ from aphrodite.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
     CompressedTensorsWNA16)
 from aphrodite.quantization.compressed_tensors.utils import (
-    CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    QuantizationType, find_matched_target, is_activation_quantization_format,
+    find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
 from aphrodite.quantization.kv_cache import BaseKVCacheMethod
 

+ 1 - 2
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py

@@ -1,14 +1,13 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 
 from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
 from aphrodite.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from aphrodite.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from aphrodite.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from aphrodite.quantization.utils.w8a8_utils import convert_to_channelwise
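
Here only `QuantizationStrategy` moves to the library import; the scheme still uses it to decide between channelwise and per-tensor scale parameters. A hypothetical helper (not the scheme's actual code) showing that dispatch:

```python
from compressed_tensors.quantization import QuantizationStrategy

def scale_shape(strategy: QuantizationStrategy, out_features: int) -> tuple:
    # CHANNEL keeps one scale per output channel (ChannelQuantScaleParameter);
    # other strategies here collapse to a single per-tensor scale
    # (PerTensorScaleParameter).
    if strategy == QuantizationStrategy.CHANNEL:
        return (out_features, 1)
    return (1,)
```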

+ 1 - 2
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

@@ -1,6 +1,7 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
 from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
@@ -8,8 +9,6 @@ from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
                                           PerTensorScaleParameter)
 from aphrodite.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from aphrodite.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from aphrodite.quantization.utils.w8a8_utils import (apply_fp8_linear,
                                                      cutlass_fp8_supported,
                                                      requantize_with_max_scale)
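
The FP8 W8A8 scheme gets the same one-line import swap. For orientation, the math it ultimately dispatches to `apply_fp8_linear` is roughly the reference below; this is a conceptual sketch, not the fused CUTLASS kernel, and the ±448 clamp is simply the `float8_e4m3fn` representable range:

```python
import torch

def fp8_linear_ref(x: torch.Tensor, w_fp8: torch.Tensor,
                   w_scale: torch.Tensor,
                   x_scale: torch.Tensor) -> torch.Tensor:
    # Quantize activations with a static per-tensor scale, then dequantize
    # both operands and matmul in fp32 (the real kernel fuses all of this).
    x_q = (x / x_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    x_dq = x_q.to(torch.float32) * x_scale
    w_dq = w_fp8.to(torch.float32) * w_scale
    return x_dq @ w_dq.t()
```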

+ 1 - 2
aphrodite/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

@@ -1,6 +1,7 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
 from aphrodite.modeling.parameter import (BaseAphroditeParameter,
@@ -9,8 +10,6 @@ from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                           PerTensorScaleParameter)
 from aphrodite.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from aphrodite.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from aphrodite.quantization.utils.w8a8_utils import (apply_int8_linear,
                                                      convert_to_channelwise)
 
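The INT8 scheme follows the same pattern. For intuition, an unfused reference of symmetric static INT8 W8A8 (the real path is `apply_int8_linear`):

```python
import torch

def int8_linear_ref(x: torch.Tensor, w_int8: torch.Tensor,
                    w_scale: torch.Tensor,
                    x_scale: torch.Tensor) -> torch.Tensor:
    # Symmetric quantization: round to nearest, clamp to the int8 range,
    # accumulate in int32, then rescale back to float.
    x_q = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    acc = x_q.to(torch.int32) @ w_int8.to(torch.int32).t()
    return acc.to(torch.float32) * (x_scale * w_scale)
```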

+ 2 - 74
aphrodite/quantization/compressed_tensors/utils.py

@@ -1,84 +1,12 @@
 import re
-from enum import Enum
-from typing import Any, Dict, Iterable, Optional
+from typing import Iterable, Optional
 
-from pydantic import BaseModel, Field
+from compressed_tensors import CompressionFormat
 from torch.nn import Module
 
 from aphrodite.quantization.utils.quant_utils import FUSED_LAYER_NAME_MAPPING
 
 
-class CompressionFormat(Enum):
-    dense = "dense"
-    sparse_bitmask = "sparse-bitmask"
-    naive_quantized = "naive-quantized"
-    float_quantized = "float-quantized"
-    int_quantized = "int-quantized"
-    pack_quantized = "pack-quantized"
-    marlin_24 = "marlin-24"
-
-
-class QuantizationType(str, Enum):
-    """
-    Enum storing quantization type options
-    """
-
-    INT = "int"
-    FLOAT = "float"
-
-
-class QuantizationStrategy(str, Enum):
-    """
-    Enum storing quantization strategy options
-    """
-
-    TENSOR = "tensor"
-    CHANNEL = "channel"
-    GROUP = "group"
-    BLOCK = "block"
-    TOKEN = "token"
-
-
-class QuantizationArgs(BaseModel):
-    """
-    User facing arguments used to define a quantization config 
-    for weights or activations
-
-    :param num_bits: quantization bit depth
-    :param type: dtype to quantized to, either int or float
-    :param symmetric: whether or not quantization scale is symmetric
-    :param strategy: string determining the scope of scale/zero-point to apply
-    :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block 
-    strategy, must be of the format "2x4", "8x16", etc.
-    :param dynamic: set True to perform dynamic quantization -
-        values will not be calibrated during calibration phase, 
-        instead during inference new quantization ranges will be 
-        observed with every sample. Defaults to False for static
-        quantization. Note that enabling dynamic quantization 
-        will change the default observer to a memoryless one
-    """
-
-    num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
-    symmetric: bool = True
-    group_size: Optional[int] = None
-    strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
-    dynamic: bool = False
-    observer: str = Field(
-        default="minmax",
-        description=("The class to use to compute the quantization param - "
-                     "scale and zero-point'"),
-    )
-    observer_kwargs: Dict[str, Any] = Field(
-        default_factory=dict,
-        description=
-        ("optional dict of kwargs to be passed directly to torch quantization "
-         "Observers constructor excluding quantization range or symmetry"),
-    )
-
-
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
         CompressionFormat.naive_quantized.value,
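
With the vendored enums and the `QuantizationArgs` model deleted, `utils.py` keeps only its helper functions and takes `CompressionFormat` from the library. The swap is only safe if the library enum carries the same string values as the deleted one, since those strings come from checkpoint configs; a quick equivalence check, assuming members match as the drop-in import implies:

```python
from compressed_tensors import CompressionFormat

# These values match the enum deleted above; if the library diverged,
# parsing existing checkpoint configs would break here.
assert CompressionFormat("int-quantized") is CompressionFormat.int_quantized
assert CompressionFormat.float_quantized.value == "float-quantized"
assert CompressionFormat.pack_quantized.value == "pack-quantized"
```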

+ 2 - 1
requirements-common.txt

@@ -26,4 +26,5 @@ loguru
 hf_transfer # for faster downloads
 librosa  # Required for audio processing
 soundfile  # Required for audio processing
-gguf == 0.9.1
+gguf == 0.9.1
+compressed-tensors == 0.5.0
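
`compressed-tensors` becomes a runtime dependency, pinned to 0.5.0 to match the test requirement bump below. A one-line sanity check of the resolved version:

```python
from importlib.metadata import version

assert version("compressed-tensors") == "0.5.0"
```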

+ 1 - 1
requirements-test.txt

@@ -18,7 +18,7 @@ requests
 ray
 sentence-transformers # required for embedding
 sparseml==1.8.0 # required for compressed-tensors
-compressed-tensors==0.4.0 # required for compressed-tensors
+compressed-tensors==0.5.0 # required for compressed-tensors
 timm # required for internvl test
 
 # Benchmarking