compressed_tensors.py

from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.quantization.base_config import QuantizationConfig  # noqa: E501
from aphrodite.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme, CompressedTensorsW4A16,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
    CompressedTensorsW8A8StaticTensor)
from aphrodite.quantization.compressed_tensors.utils import (
    CompressionFormat, QuantizationArgs, QuantizationStrategy,
    find_first_name_or_class_match)


class CompressedTensorsConfig(QuantizationConfig):

    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
                 quant_format: str):
        self.ignore = ignore
        self.layer_quant_details = layer_quant_details
        self.quant_format = quant_format

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    # TODO: figure out the correct minimum capability for these schemes.
    def get_min_capability(self) -> int:
        return 60

    def get_name(self) -> str:
        return "compressed_tensors"

    def get_quant_method(
            self, layer: torch.nn.Module
    ) -> Optional["CompressedTensorsLinearMethod"]:
        if isinstance(layer, LinearBase):
            return CompressedTensorsLinearMethod(self)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        layer_quant_details: Dict[str, Any] = dict()
        ignore: List[str] = config.get("ignore", None)
        quant_format: str = config.get("format", None)

        # The quant_config has multiple config_groups, each containing
        # an input_activations key with details about how the activations are
        # quantized, a weights key indicating how the weights are quantized,
        # and a list of targets under the `targets` key dictating which
        # layers the quantization details apply to. The quantization details
        # follow the structure defined by the QuantizationArgs pydantic
        # model, which is used both to validate the structure of the
        # quant_config and to store the details for later use.
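        # For illustration only (hypothetical values; real checkpoints may
        # differ), a config_groups entry might look like:
        #
        #   "config_groups": {
        #       "group_0": {
        #           "targets": ["Linear"],
        #           "weights": {"num_bits": 8, "strategy": "tensor",
        #                       "symmetric": True, "dynamic": False},
        #           "input_activations": {"num_bits": 8, "strategy": "tensor",
        #                                 "symmetric": True, "dynamic": False},
        #       }
        #   }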
        for key, quant_config in config["config_groups"].items():
            targets = quant_config.get("targets")
            for target in targets:
                layer_quant_details[target] = {}
                layer_quant_details[target][
                    "weights"] = QuantizationArgs.parse_obj(
                        quant_config.get("weights"))
                try:
                    layer_quant_details[target][
                        "input_activations"] = QuantizationArgs.parse_obj(
                            quant_config.get("input_activations"))
                except Exception:
                    layer_quant_details[target]["input_activations"] = None

        return cls(layer_quant_details=layer_quant_details,
                   ignore=ignore,
                   quant_format=quant_format)

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_tensor = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TENSOR.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_static = not weight_quant.dynamic and not input_quant.dynamic

        return is_8_bits and is_tensor and is_symmetric and is_static

    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_token = (weight_strategy and input_quant.strategy
                    == QuantizationStrategy.TOKEN.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic

        return is_8_bits and is_token and is_symmetric and is_dynamic

    def _is_w4a16(self, weight_quant: BaseModel,
                  input_quant: Optional[BaseModel]) -> bool:
        input_quant_none = input_quant is None
        is_4_bits = weight_quant.num_bits == 4
        is_symmetric = weight_quant.symmetric
        is_static = not weight_quant.dynamic

        return is_4_bits and input_quant_none and is_symmetric and is_static
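
    # For illustration (hypothetical QuantizationArgs values): a weight
    # config of num_bits=8, strategy="tensor", symmetric=True, dynamic=False
    # paired with an identical input config satisfies _is_static_tensor_w8a8;
    # switching the input config to strategy="token", dynamic=True satisfies
    # _is_dynamic_token_w8a8; and a 4-bit symmetric static weight config with
    # no input_activations config at all satisfies _is_w4a16.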

    def _get_schema(self, weight_quant: BaseModel,
                    input_quant: Optional[BaseModel]
                    ) -> "CompressedTensorsScheme":
        if self._is_w4a16(weight_quant, input_quant):
            if self.quant_format == CompressionFormat.marlin_24.value:
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size)
            if self.quant_format == CompressionFormat.pack_quantized.value:
                return CompressedTensorsW4A16(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    group_size=weight_quant.group_size)

        if self.quant_format == CompressionFormat.int_quantized.value:
            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8StaticTensor(
                    strategy=weight_quant.strategy)

            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8DynamicToken(
                    strategy=weight_quant.strategy)

        raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")

    def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
        layer_type_name = find_first_name_or_class_match(
            name="",
            module=layer,
            targets=self.layer_quant_details.keys(),
            check_contains=True)
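        # With check_contains=True the matching is by substring, so (for
        # example) a "Linear" target is expected to also match subclasses
        # such as ColumnParallelLinear by class name.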

        if layer_type_name is None:
            raise ValueError(f"Could not match a target for layer {layer}")

        layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
            layer_type_name, None)
        if layer_quant_details is None:
            raise ValueError(
                f"Could not find quantization details for {layer}.")

        return self._get_schema(
            weight_quant=layer_quant_details["weights"],
            input_quant=layer_quant_details["input_activations"])


class CompressedTensorsLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for
        param details.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")

        scheme = self.quantization_config.get_scheme(layer=layer)
        scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader)

        layer.scheme = scheme

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details.
        """
        if bias is not None:
            raise ValueError("bias is not supported for this linear method")

        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")

        return scheme.apply_weights(layer, x)
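

# Minimal end-to-end sketch (hypothetical layer and input; assumes a quant
# config shaped like the config_groups example in from_config above):
#
#   config = CompressedTensorsConfig.from_config(hf_quant_config_dict)
#   method = config.get_quant_method(layer)  # CompressedTensorsLinearMethod
#   method.create_weights(layer, input_size_per_partition,
#                         output_partition_sizes, input_size, output_size,
#                         params_dtype, weight_loader=weight_loader)
#   out = method.apply(layer, x)             # quantized forward pass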