compressed_tensors.py

from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (  # noqa: E501
    QuantizationConfig, QuantizeMethodBase)
from aphrodite.quantization.compressed_tensors.schemes import (
    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
    CompressedTensorsScheme, CompressedTensorsUnquantized,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16)
from aphrodite.quantization.compressed_tensors.utils import (
    CompressionFormat, QuantizationArgs, QuantizationStrategy,
    QuantizationType, find_matched_target, is_activation_quantization_format,
    should_ignore_layer)
from aphrodite.quantization.kv_cache import BaseKVCacheMethod

__all__ = ["CompressedTensorsLinearMethod"]


class CompressedTensorsConfig(QuantizationConfig):

    def __init__(self,
                 target_scheme_map: Dict[str, Any],
                 ignore: List[str],
                 quant_format: str,
                 kv_cache_scheme: Optional[Dict[str, Any]] = None):
        self.ignore = ignore
        self.quant_format = quant_format
        # Map from [target -> scheme]
        self.target_scheme_map = target_scheme_map
        self.kv_cache_scheme = kv_cache_scheme

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
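        # Volta and newer (compute capability 7.0+); see
        # _check_scheme_supported for how the capability is encoded.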
        return 70

    def get_name(self) -> str:
        return "compressed_tensors"

    # TODO: do layer skipping through here
    # rather than through create_weights to match other methods
    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from aphrodite.attention.layer import (
            Attention)  # Avoid circular import
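        # Linear layers get the quantized linear method; attention layers
        # get kv-cache scale loading. Other layer types are not handled here.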
        if isinstance(layer, LinearBase):
            return CompressedTensorsLinearMethod(self)
        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        target_scheme_map: Dict[str, Any] = dict()
        ignore: List[str] = config.get("ignore", None)
        quant_format: str = config.get("format", None)

        # The quant_config has multiple config_groups, each containing
        # an input_activations key with details about how the activations are
        # quantized, a weights key indicating how the weights are quantized,
        # and a list of targets under the `targets` key, dictating which
        # layers are impacted by the quantization details. The quantization
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.
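        # For illustration only (not taken from a real checkpoint), the
        # expected shape of `config` is roughly:
        #   {
        #       "format": "int-quantized",
        #       "ignore": ["lm_head"],
        #       "config_groups": {
        #           "group_0": {
        #               "targets": ["Linear"],
        #               "weights": {...},            # QuantizationArgs fields
        #               "input_activations": {...},  # optional
        #           }
        #       },
        #       "kv_cache_scheme": {...},            # optional
        #   }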
        for _, quant_config in config["config_groups"].items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target][
                    "weights"] = QuantizationArgs.parse_obj(
                        quant_config.get("weights"))
                try:
                    target_scheme_map[target][
                        "input_activations"] = QuantizationArgs.parse_obj(
                            quant_config.get("input_activations"))
                except Exception:
                    target_scheme_map[target]["input_activations"] = None

        return cls(target_scheme_map=target_scheme_map,
                   ignore=ignore,
                   quant_format=quant_format,
                   kv_cache_scheme=config.get("kv_cache_scheme"))

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    def _check_scheme_supported(self,
                                min_capability: int,
                                error: bool = True) -> bool:
        capability = current_platform.get_device_capability()
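        # Flatten (major, minor) into a single comparable int,
        # e.g. (8, 0) -> 80.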
        capability = capability[0] * 10 + capability[1]
        supported = capability >= min_capability
        if error and not supported:
            raise RuntimeError(
                "Quantization scheme is not supported for "
                f"the current GPU. Min capability: {min_capability}. "
                f"Current capability: {capability}.")
        return supported
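
    # Static, symmetric int8 W8A8: 8-bit weights (per-tensor or per-channel
    # scales) and 8-bit activations with a pre-computed per-tensor scale.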
    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_tensor = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TENSOR.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_static = not weight_quant.dynamic and not input_quant.dynamic
        return is_8_bits and is_tensor and is_symmetric and is_static
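
    # Dynamic per-token int8 W8A8: 8-bit weights with static scales and
    # 8-bit activations whose scales are computed at runtime per token.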
    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_token = (weight_strategy and input_quant.strategy
                    == QuantizationStrategy.TOKEN.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
        return is_8_bits and is_token and is_symmetric and is_dynamic
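
    # FP8 weights and FP8 activations (static or dynamic activation scales).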
    def _is_fp8_w8a8(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> bool:
        # Confirm weights and activations quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm weight scheme is supported.
        is_floating_point = (weight_quant.type == QuantizationType.FLOAT
                             and input_quant.type == QuantizationType.FLOAT)
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
        ])
        if not (is_floating_point and is_symmetric_weight and is_static_weight
                and is_per_tensor_or_channel_weight):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.dynamic:
            return True

        # Confirm activation scheme is supported.
        is_symmetric_activation = input_quant.symmetric
        is_per_tensor_activation = (
            input_quant.strategy == QuantizationStrategy.TENSOR)
        return is_symmetric_activation and is_per_tensor_activation
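
    # FP8 weights with unquantized (16-bit) activations, i.e. weight-only FP8.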
    def _is_fp8_w8a16(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
        # Confirm weights quantized.
        if weight_quant is None:
            return False

        # Confirm we have floating points.
        if weight_quant.type != QuantizationType.FLOAT:
            return False

        # Confirm weight scheme is supported.
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
        ])
        if not (is_symmetric_weight and is_static_weight
                and is_per_tensor_or_channel_weight):
            return False

        # All conditions satisfied.
        return True
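
    # Weight-only quantization (no activation quantization) with static,
    # symmetric per-channel or per-group weight scales (WNA16-style).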
    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
        input_quant_none = input_quant is None
        is_symmetric = weight_quant.symmetric
        is_channel_group = (
            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
            or weight_quant.strategy == QuantizationStrategy.GROUP.value)
        is_static = not weight_quant.dynamic
        return (is_channel_group and input_quant_none and is_symmetric
                and is_static)

    def _get_scheme_from_parts(
            self, weight_quant: BaseModel,
            input_quant: BaseModel) -> "CompressedTensorsScheme":

        # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if (self.quant_format == CompressionFormat.marlin_24.value
                    and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size)
            if (self.quant_format == CompressionFormat.pack_quantized.value
                    and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
                return CompressedTensorsWNA16(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    group_size=weight_quant.group_size,
                    actorder=weight_quant.actorder)

        # Detect If Activation Quantization.
        # TODO @dsikka: clean-up conditions
        if is_activation_quantization_format(self.quant_format):
            if self._is_fp8_w8a8(weight_quant, input_quant):
                is_fp8_w8a8_supported = self._check_scheme_supported(
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
                if is_fp8_w8a8_supported:
                    return CompressedTensorsW8A8Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=(input_quant
                                                and not input_quant.dynamic))
                else:
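                    # Fall back to weight-only FP8 (W8A16) when the device
                    # cannot run the fused FP8 W8A8 scheme.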
                    return CompressedTensorsW8A16Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=(input_quant
                                                and not input_quant.dynamic))

            if self._is_fp8_w8a16(weight_quant, input_quant):
                return CompressedTensorsW8A16Fp8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=(input_quant
                                            and not input_quant.dynamic))

            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=True)

            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=False)

        raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")

    def get_scheme(
            self,
            layer: torch.nn.Module,
            layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
        """
        compressed-tensors supports non-uniform quantization in the
        following way:

        ignore: List of layer_names or nn.Module names to be ignored.
        targets of config_groups: There can be N config_groups, each with a
            quantization scheme. Each config_group has a list of targets
            which can be a full layer_name, a regex for a layer_name, or
            an nn.Module name.

        We first check whether a layer is in the ignore group and, if so, use
        the CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the
        layer. We then detect whether the layer_name is found in any target
        and use the quantization scheme corresponding to the matched target
        to select the CompressedTensorsScheme used for inference.
        """

        # Check if the layer is skipped for quantization.
        # TODO: support module names
        if should_ignore_layer(layer_name, ignore=self.ignore):
            return CompressedTensorsUnquantized()

        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO: add compressed-tensors as dep
        # so we do not have to re-write these functions
        matched_target = find_matched_target(
            layer_name=layer_name,
            module=layer,
            targets=self.target_scheme_map.keys())

        # Find the quant_scheme
        scheme_dict = self.target_scheme_map[matched_target]
        scheme = self._get_scheme_from_parts(
            weight_quant=scheme_dict["weights"],
            input_quant=scheme_dict["input_activations"])

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme


class CompressedTensorsLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
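        # The fully-qualified module name arrives via "prefix" and is used to
        # match the layer against the config's targets and ignore list.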
        layer_name = extra_weight_attrs.get("prefix")
        scheme = self.quantization_config.get_scheme(layer, layer_name)
        scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader)
        layer.scheme = scheme

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details.
        """
        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return scheme.apply_weights(layer, x, bias=bias)


class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from compressed-tensors
    checkpoints.
    """

    def __init__(self, quant_config: CompressedTensorsConfig):
        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
        """
        Validator for the kv cache scheme. Useful for controlling the
        kv cache quantization schemes that are supported in Aphrodite.

        :param kv_cache_scheme: the compressed-tensors kv cache scheme
        """
        if kv_cache_scheme is None:
            return

        type_ = kv_cache_scheme.get("type")
        num_bits = kv_cache_scheme.get("num_bits")

        if type_ != "float" and num_bits != 8:
            raise NotImplementedError(
                "Currently supported kv cache quantization is "
                "num_bits=8, type=float, however "
                f"received num_bits={num_bits}, type={type_}")

        strategy = kv_cache_scheme.get("strategy")
        if strategy != "tensor":
            raise NotImplementedError(
                "Only support per-tensor scaling factor "
                "for compressed-tensors KV cache. "
                f"Expected strategy: tensor, found strategy: {strategy}")

        is_symmetric = kv_cache_scheme.get("symmetric")
        if not is_symmetric:
            raise NotImplementedError(
                "Only support symmetric scaling factor "
                "for compressed-tensors KV cache. "
                f"However found symmetric: {is_symmetric}")