compressed_tensors.py

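"""Quantization support for models saved in the compressed-tensors format.

CompressedTensorsConfig parses a checkpoint's quantization config and maps
each quantized layer onto a CompressedTensorsScheme, which creates the
layer's parameters and runs its quantized forward pass.
"""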
from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.quantization.base_config import QuantizationConfig  # noqa: E501
from aphrodite.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme, CompressedTensorsW4A16,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
    CompressedTensorsW8A8StaticTensor)
from aphrodite.quantization.compressed_tensors.utils import (
    CompressionFormat, QuantizationArgs, QuantizationStrategy,
    find_first_name_or_class_match)


class CompressedTensorsConfig(QuantizationConfig):

    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
                 quant_format: str):
        self.ignore = ignore
        self.layer_quant_details = layer_quant_details
        self.quant_format = quant_format

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    # TODO: confirm the actual minimum compute capability required.
    # 60 corresponds to compute capability 6.0 (Pascal, e.g. P100).
    def get_min_capability(self) -> int:
        return 60

    def get_name(self) -> str:
        return "compressed_tensors"

    def get_quant_method(
            self, layer: torch.nn.Module
    ) -> Optional["CompressedTensorsLinearMethod"]:
        # Only linear layers are quantized; everything else is left as-is.
        if isinstance(layer, LinearBase):
            return CompressedTensorsLinearMethod(self)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        layer_quant_details: Dict[str, Any] = dict()
        ignore: List[str] = config.get("ignore", None)
        quant_format: str = config.get("format", None)

        # The quant_config has multiple config_groups, each containing an
        # input_activations key with details about how the activations are
        # quantized, a weights key indicating how the weights are quantized,
        # and a list of targets under the `targets` key dictating which
        # layers the quantization details apply to. The quantization details
        # follow the structure defined by the QuantizationArgs pydantic
        # model, which both validates the structure of the quant_config and
        # stores the details for later use.
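        # For illustration, a hypothetical quant_config fragment using only
        # the keys read below (the "format" string would be one of the
        # CompressionFormat values):
        #
        #   {
        #       "format": "int-quantized",
        #       "ignore": ["lm_head"],
        #       "config_groups": {
        #           "group_0": {
        #               "targets": ["Linear"],
        #               "weights": {"num_bits": 8, "symmetric": True,
        #                           "strategy": "tensor", "dynamic": False},
        #               "input_activations": {"num_bits": 8,
        #                                     "symmetric": True,
        #                                     "strategy": "tensor",
        #                                     "dynamic": False}
        #           }
        #       }
        #   }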
        for key, quant_config in config["config_groups"].items():
            targets = quant_config.get("targets")
            for target in targets:
                layer_quant_details[target] = {}
                layer_quant_details[target][
                    "weights"] = QuantizationArgs.parse_obj(
                        quant_config.get("weights"))
                try:
                    layer_quant_details[target][
                        "input_activations"] = QuantizationArgs.parse_obj(
                            quant_config.get("input_activations"))
                except Exception:
                    # Weight-only schemes have no input_activations entry;
                    # parse_obj fails on the resulting None.
                    layer_quant_details[target]["input_activations"] = None

        return cls(layer_quant_details=layer_quant_details,
                   ignore=ignore,
                   quant_format=quant_format)

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []
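
    # The _is_* predicates below classify a parsed (weight_quant,
    # input_quant) pair; _get_schema uses them to pick the kernel scheme
    # for a layer.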
    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        is_tensor = (weight_quant.strategy == input_quant.strategy ==
                     QuantizationStrategy.TENSOR.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_static = not weight_quant.dynamic and not input_quant.dynamic

        return is_8_bits and is_tensor and is_symmetric and is_static

    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        is_token_tensor = (weight_quant.strategy
                           == QuantizationStrategy.TENSOR.value) and (
                               input_quant.strategy
                               == QuantizationStrategy.TOKEN.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic

        return is_8_bits and is_token_tensor and is_symmetric and is_dynamic

    def _is_w4a16(self, weight_quant: BaseModel,
                  input_quant: BaseModel) -> bool:
        input_quant_none = input_quant is None
        is_4_bits = weight_quant.num_bits == 4
        is_symmetric = weight_quant.symmetric
        is_static = not weight_quant.dynamic

        return is_4_bits and input_quant_none and is_symmetric and is_static
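
    # Scheme dispatch implemented by _get_schema below:
    #   W4 weight-only + marlin_24 format       -> W4A16Sparse24
    #   W4 weight-only + pack_quantized format  -> W4A16
    #   static per-tensor W8A8 + int_quantized  -> W8A8StaticTensor
    #   dynamic per-token W8A8 + int_quantized  -> W8A8DynamicToken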
    def _get_schema(self, weight_quant: BaseModel,
                    input_quant: BaseModel) -> "CompressedTensorsScheme":
        if self._is_w4a16(weight_quant, input_quant):
            if self.quant_format == CompressionFormat.marlin_24.value:
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size)
            if self.quant_format == CompressionFormat.pack_quantized.value:
                return CompressedTensorsW4A16(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    group_size=weight_quant.group_size)

        if self.quant_format == CompressionFormat.int_quantized.value:
            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8StaticTensor()
            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8DynamicToken()

        raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")

    def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
        layer_type_name = find_first_name_or_class_match(
            name="",
            module=layer,
            targets=self.layer_quant_details.keys(),
            check_contains=True)

        if layer_type_name is None:
            raise ValueError(f"Could not match a target for layer {layer}")

        layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
            layer_type_name, None)
        if layer_quant_details is None:
            raise ValueError(
                f"Could not find quantization details for {layer}.")

        return self._get_schema(
            weight_quant=layer_quant_details["weights"],
            input_quant=layer_quant_details["input_activations"])


class CompressedTensorsLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for
        param details.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")

        scheme = self.quantization_config.get_scheme(layer=layer)
        scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader)

        # Cache the scheme on the layer so apply() can retrieve it later.
        layer.scheme = scheme

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the layer
        input. See LinearMethodBase for param details.
        """
        if bias is not None:
            raise ValueError("bias is not supported for this linear method")

        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")

        return scheme.apply_weights(layer, x)
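
# Minimal usage sketch (illustrative only; `layer`, `x`, `loader`, and
# `config` are hypothetical stand-ins for a LinearBase module, its input,
# a weight loader callable, and a parsed quantization config dict):
#
#   quant_config = CompressedTensorsConfig.from_config(config)
#   method = quant_config.get_quant_method(layer)  # None if not LinearBase
#   method.create_weights(layer, input_size_per_partition,
#                         output_partition_sizes, input_size, output_size,
#                         params_dtype, weight_loader=loader)
#   out = method.apply(layer, x)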