@@ -7,16 +7,20 @@ from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
 from aphrodite.quantization.base_config import QuantizationConfig  # noqa: E501
 from aphrodite.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme, CompressedTensorsW4A16,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
 from aphrodite.quantization.compressed_tensors.utils import (
-    QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
+    CompressionFormat, QuantizationArgs, QuantizationStrategy,
+    find_first_name_or_class_match)
 
 
 class CompressedTensorsConfig(QuantizationConfig):
 
-    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]):
+    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
+                 quant_format: str):
         self.ignore = ignore
         self.layer_quant_details = layer_quant_details
+        self.quant_format = quant_format
 
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
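
For context on the new `CompressionFormat` import: it enumerates the serialization formats a compressed-tensors checkpoint may use, and the scheme dispatch below compares against its `.value` strings. A minimal sketch, assuming value strings that mirror the upstream compressed-tensors naming (only the member names actually appear in this diff):

from enum import Enum


class CompressionFormat(Enum):
    # String values are assumptions; only the member names are
    # confirmed by the code in this diff.
    int_quantized = "int-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"
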
@@ -45,6 +49,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         layer_quant_details: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)
+        quant_format: str = config.get("format", None)
 
         # The quant_config has multiple config_groups, each containing
         # an input_activations key with details about how the activations are
@@ -68,7 +73,9 @@ class CompressedTensorsConfig(QuantizationConfig):
                 except Exception:
                     layer_quant_details[target]["input_activations"] = None
 
-        return cls(layer_quant_details=layer_quant_details, ignore=ignore)
+        return cls(layer_quant_details=layer_quant_details,
+                   ignore=ignore,
+                   quant_format=quant_format)
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
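
To make the new `format` key concrete, here is a hypothetical `quantization_config` fragment of the shape `from_config` consumes; all values are illustrative, and the `config_groups` keys are inferred from the comments and parsing code around this hunk rather than shown in it:

# Hypothetical quantization_config block from a model's config.json.
example_config = {
    "format": "marlin-24",  # becomes quant_format (value assumed)
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 4, "strategy": "group",
                        "group_size": 128},
            "input_activations": None,  # weight-only quantization
        },
    },
}
quant_config = CompressedTensorsConfig.from_config(example_config)
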
@@ -109,17 +116,26 @@ class CompressedTensorsConfig(QuantizationConfig):
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
 
         if self._is_w4a16(weight_quant, input_quant):
-            return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
-                                          strategy=weight_quant.strategy,
-                                          group_size=weight_quant.group_size)
-
-        if self._is_static_tensor_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8StaticTensor()
-
-        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8DynamicToken()
-
-        raise NotImplementedError("Scheme not supported.")
+            if self.quant_format == CompressionFormat.marlin_24.value:
+                return CompressedTensorsW4A16Sparse24(
+                    strategy=weight_quant.strategy,
+                    num_bits=weight_quant.num_bits,
+                    group_size=weight_quant.group_size)
+            if self.quant_format == CompressionFormat.pack_quantized.value:
+                return CompressedTensorsW4A16(
+                    num_bits=weight_quant.num_bits,
+                    strategy=weight_quant.strategy,
+                    group_size=weight_quant.group_size)
+
+        if self.quant_format == CompressionFormat.int_quantized.value:
+            if self._is_static_tensor_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8StaticTensor()
+
+            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8DynamicToken()
+
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")
 
     def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
 
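
Net effect of this hunk: scheme selection is now keyed on the checkpoint's serialization format as well as the quantization arguments, so the same w4a16 arguments can route either to the dense `CompressedTensorsW4A16` kernel or to the 2:4-sparse `CompressedTensorsW4A16Sparse24` kernel. A usage sketch of the resulting behavior; the helper name `_get_schema` is an assumption (only the signature tail is visible above), and the argument objects are placeholders:

# For fixed w4a16 quantization args, dispatch now depends on quant_format:
#   pack_quantized -> CompressedTensorsW4A16
#   marlin_24      -> CompressedTensorsW4A16Sparse24
try:
    scheme = quant_config._get_schema(weight_quant, input_quant)
except NotImplementedError as err:
    # Unmatched format/args combinations now fail with a clearer message.
    print(err)  # No compressed-tensors compatible scheme was found.
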
@@ -164,9 +180,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         scheme = self.quantization_config.get_scheme(layer=layer)
         scheme.create_weights(
             layer=layer,
+            input_size=input_size,
             input_size_per_partition=input_size_per_partition,
             output_partition_sizes=output_partition_sizes,
-            input_size=input_size,
             output_size=output_size,
             params_dtype=params_dtype,
             weight_loader=weight_loader)
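
The final hunk only moves `input_size` earlier in the call; every argument is passed by keyword, so behavior is unchanged and the reorder presumably just matches the declared parameter order of the scheme's `create_weights`. For reference, the signature implied by this call site (a sketch, not the actual abstract method on `CompressedTensorsScheme`):

from typing import Callable, List

import torch


def create_weights(self, layer: torch.nn.Module, input_size: int,
                   input_size_per_partition: int,
                   output_partition_sizes: List[int], output_size: int,
                   params_dtype: torch.dtype, weight_loader: Callable):
    """Allocate and register the quantized weight parameters on `layer`."""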