compressed_tensors.py

from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import QuantizationConfig  # noqa: E501
from aphrodite.quantization.compressed_tensors.schemes import (
    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
    CompressedTensorsScheme, CompressedTensorsUnquantized,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
from aphrodite.quantization.compressed_tensors.utils import (
    CompressionFormat, QuantizationArgs, QuantizationStrategy,
    QuantizationType, find_matched_target, is_activation_quantization_format,
    should_ignore_layer)


class CompressedTensorsConfig(QuantizationConfig):

    def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str],
                 quant_format: str):
        self.ignore = ignore
        self.quant_format = quant_format
        # Map from [target -> scheme]
        self.target_scheme_map = target_scheme_map

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # Volta (compute capability 7.0) and newer.
        return 70

    def get_name(self) -> str:
        return "compressed_tensors"

    # TODO: do layer skipping here rather than through create_weights
    # to match other methods.
    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["CompressedTensorsLinearMethod"]:
        if isinstance(layer, LinearBase):
            return CompressedTensorsLinearMethod(self)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        target_scheme_map: Dict[str, Any] = dict()
        ignore: List[str] = config.get("ignore", None)
        quant_format: str = config.get("format", None)

        # The quant_config has multiple config_groups, each containing
        # an input_activations key with details about how the activations
        # are quantized, a weights key indicating how the weights are
        # quantized, and a list of targets under the `targets` key,
        # dictating which layers are impacted by the quantization details.
        # The quantization details follow the structure defined by the
        # QuantizationArgs pydantic model, which is used to verify the
        # structure of the quant_config and also store the details for
        # later use.
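        #
        # A config_groups entry roughly looks like the sketch below
        # (illustrative only; the field values are assumptions, not taken
        # from any particular checkpoint):
        #
        #   "group_0": {
        #       "targets": ["Linear"],
        #       "weights": {"num_bits": 8, "type": "int", "symmetric": True,
        #                   "strategy": "tensor", "dynamic": False},
        #       "input_activations": {"num_bits": 8, "type": "int",
        #                             "symmetric": True, "strategy": "tensor",
        #                             "dynamic": False},
        #   }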
        for _, quant_config in config["config_groups"].items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target][
                    "weights"] = QuantizationArgs.parse_obj(
                        quant_config.get("weights"))
                try:
                    target_scheme_map[target][
                        "input_activations"] = QuantizationArgs.parse_obj(
                            quant_config.get("input_activations"))
                except Exception:
                    target_scheme_map[target]["input_activations"] = None

        return cls(target_scheme_map=target_scheme_map,
                   ignore=ignore,
                   quant_format=quant_format)

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    def _check_scheme_supported(self, min_capability: int):
        capability = current_platform.get_device_capability()
        # Encode (major, minor) as a single integer, e.g. (8, 9) -> 89.
        capability = capability[0] * 10 + capability[1]

        if capability < min_capability:
            raise RuntimeError(
                "Quantization scheme is not supported for "
                f"the current GPU. Min capability: {min_capability}. "
                f"Current capability: {capability}.")

    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_tensor = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TENSOR.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_static = not weight_quant.dynamic and not input_quant.dynamic

        return is_8_bits and is_tensor and is_symmetric and is_static

    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_token = (weight_strategy and input_quant.strategy
                    == QuantizationStrategy.TOKEN.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic

        return is_8_bits and is_token and is_symmetric and is_dynamic

    def _is_fp8_w8a8(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> bool:
        # Confirm weights and activations quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm we have floating points.
        if not (weight_quant.type == QuantizationType.FLOAT
                and input_quant.type == QuantizationType.FLOAT):
            return False

        # Confirm weight scheme is supported.
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
        ])
        if not (is_symmetric_weight and is_static_weight
                and is_per_tensor_or_channel_weight):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.dynamic:
            return True

        # Confirm activation scheme is supported.
        is_symmetric_activation = input_quant.symmetric
        is_per_tensor_activation = (
            input_quant.strategy == QuantizationStrategy.TENSOR)
        if not (is_symmetric_activation and is_per_tensor_activation):
            return False

        # All conditions satisfied.
        return True

    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
        input_quant_none = input_quant is None
        is_symmetric = weight_quant.symmetric
        is_channel_group = (
            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
            or weight_quant.strategy == QuantizationStrategy.GROUP.value)
        is_static = not weight_quant.dynamic

        return (is_channel_group and input_quant_none and is_symmetric
                and is_static)

    def _get_scheme_from_parts(
            self, weight_quant: BaseModel,
            input_quant: BaseModel) -> "CompressedTensorsScheme":

        # Detect if mixed precision.
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if (self.quant_format == CompressionFormat.marlin_24.value
                    and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size)
            if (self.quant_format == CompressionFormat.pack_quantized.value
                    and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
                return CompressedTensorsWNA16(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    group_size=weight_quant.group_size)

        # Detect if activation quantization.
        if is_activation_quantization_format(self.quant_format):
            if self._is_fp8_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Fp8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=(not input_quant.dynamic))

            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=True)

            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=False)

        raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")

    def get_scheme(
            self,
            layer: torch.nn.Module,
            layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
        """
        compressed-tensors supports non-uniform quantization in the
        following way:

        ignore: List of layer_names or nn.Module names to be ignored.
        targets of config_groups: There can be N config_groups, each with
            its own quantization scheme. Each config_group has a list of
            targets, which can be a full layer_name, a regex for a
            layer_name, or an nn.Module name.

        We first check whether a layer is in the ignore list and, if so,
        use the CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for
        the layer. We then check whether the layer_name matches any
        target and use the quantization scheme of the matched target to
        select the CompressedTensorsScheme used for inference.
        """
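
        # Illustrative example of the matching (the names here are
        # hypothetical, not taken from a real model):
        #   ignore = ["lm_head"]
        #   target_scheme_map = {"Linear": {...}}
        #   -> "lm_head" is skipped and served unquantized, while
        #      "model.layers.0.self_attn.q_proj" matches the "Linear"
        #      target and uses that target's weight/activation args.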

        # Check if the layer is skipped for quantization.
        # TODO: support module names
        if should_ignore_layer(layer_name, ignore=self.ignore):
            return CompressedTensorsUnquantized()

        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO: add compressed-tensors as a dependency
        # so we do not have to re-write these functions.
        matched_target = find_matched_target(
            layer_name=layer_name,
            module=layer,
            targets=self.target_scheme_map.keys())

        # Find the quant_scheme.
        scheme_args = self.target_scheme_map[matched_target]
        scheme = self._get_scheme_from_parts(
            weight_quant=scheme_args["weights"],
            input_quant=scheme_args["input_activations"])

        # Raise an error if the device does not support the scheme
        # (e.g. fp8 needs Ada Lovelace).
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme


class CompressedTensorsLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Use the CompressedTensorsScheme associated with each layer to
        create the necessary parameters for the layer. See
        LinearMethodBase for param details.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer_name = extra_weight_attrs.get("prefix")

        scheme = self.quantization_config.get_scheme(layer, layer_name)
        scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader)

        layer.scheme = scheme

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details.
        """
        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")

        return scheme.apply_weights(layer, x, bias=bias)
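

# Usage sketch (illustrative, not part of the module): how a caller is
# expected to drive this config end to end. The layer and tensor names
# below are hypothetical.
#
#   quant_config = CompressedTensorsConfig.from_config(hf_quant_config_dict)
#   method = quant_config.get_quant_method(linear_layer, prefix="")
#   method.create_weights(linear_layer, ...)   # resolves layer.scheme
#   ...                                        # load checkpoint weights
#   method.process_weights_after_loading(linear_layer)
#   out = method.apply(linear_layer, x)        # quantized forward pass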