compressed_tensors.py

from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import (  # noqa: E501
    QuantizationConfig, QuantizeMethodBase)
from aphrodite.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsMoEMethod)
from aphrodite.quantization.compressed_tensors.schemes import (
    W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from aphrodite.quantization.compressed_tensors.utils import (
    CompressionFormat, QuantizationArgs, QuantizationStrategy,
    QuantizationType, find_matched_target, is_activation_quantization_format,
    should_ignore_layer)
from aphrodite.quantization.kv_cache import BaseKVCacheMethod

__all__ = ["CompressedTensorsLinearMethod"]


class CompressedTensorsConfig(QuantizationConfig):

    def __init__(self,
                 target_scheme_map: Dict[str, Any],
                 ignore: List[str],
                 quant_format: str,
                 kv_cache_scheme: Optional[Dict[str, Any]] = None):
        self.ignore = ignore
        self.quant_format = quant_format
        # Map from [target -> scheme]
        self.target_scheme_map = target_scheme_map
        self.kv_cache_scheme = kv_cache_scheme

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def get_name(self) -> str:
        return "compressed_tensors"

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from aphrodite.attention.layer import (
            Attention)  # Avoid circular import

        # Check if the layer is skipped for quantization.
        # TODO: support module names
        if should_ignore_layer(prefix, ignore=self.ignore):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMoEMethod(self)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        target_scheme_map: Dict[str, Any] = dict()
        ignore: List[str] = config.get("ignore", None)
        quant_format: str = config.get("format", None)

        # The quant_config has multiple config_groups, each containing
        # an input_activations key with details about how the activations are
        # quantized, a weights key indicating how the weights are quantized,
        # and a list of targets under the `targets` key, dictating which
        # layers are impacted by the quantization details. The quantization
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.
        # (See the illustrative example config after this method.)
        for _, quant_config in config["config_groups"].items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target][
                    "weights"] = QuantizationArgs.parse_obj(
                        quant_config.get("weights"))
                try:
                    target_scheme_map[target][
                        "input_activations"] = QuantizationArgs.parse_obj(
                            quant_config.get("input_activations"))
                except Exception:
                    target_scheme_map[target]["input_activations"] = None

        return cls(target_scheme_map=target_scheme_map,
                   ignore=ignore,
                   quant_format=quant_format,
                   kv_cache_scheme=config.get("kv_cache_scheme"))
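
    # Illustrative sketch (not part of the original source): a minimal quant
    # config of the shape `from_config` expects. The concrete values below are
    # assumptions for demonstration only; real checkpoints ship this dict as
    # part of their quantization config.
    #
    #   {
    #       "format": "int-quantized",
    #       "ignore": ["lm_head"],
    #       "kv_cache_scheme": None,
    #       "config_groups": {
    #           "group_0": {
    #               "targets": ["Linear"],
    #               "weights": {"num_bits": 8, "type": "int",
    #                           "symmetric": True, "strategy": "tensor",
    #                           "dynamic": False},
    #               "input_activations": {"num_bits": 8, "type": "int",
    #                                     "symmetric": True,
    #                                     "strategy": "tensor",
    #                                     "dynamic": False},
    #           }
    #       },
    #   }
    #
    # With this input, from_config maps the "Linear" target to a pair of
    # QuantizationArgs (weights + input_activations). Assuming "int-quantized"
    # counts as an activation-quantization format here, that pair would later
    # resolve to a static, per-tensor W8A8 int8 scheme in
    # _get_scheme_from_parts below.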

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    def _check_scheme_supported(self,
                                min_capability: int,
                                error: bool = True) -> bool:
        capability = current_platform.get_device_capability()  # type: ignore
        if capability is not None:
            capability = capability[0] * 10 + capability[1]
            supported = capability >= min_capability
            if error and not supported:
                raise RuntimeError(
                    "Quantization scheme is not supported for "
                    f"the current GPU. Min capability: {min_capability}. "
                    f"Current capability: {capability}.")
            return supported
        else:
            return False
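
    # Worked example (illustrative): a GPU reporting compute capability (8, 6)
    # is folded into 8 * 10 + 6 = 86 above, so a scheme with
    # min_capability=89 would be reported as unsupported, while one with
    # min_capability=80 would pass.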

    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_tensor = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TENSOR.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_static = not weight_quant.dynamic and not input_quant.dynamic
        return is_8_bits and is_tensor and is_symmetric and is_static

    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
        is_token = (weight_strategy and input_quant.strategy
                    == QuantizationStrategy.TOKEN.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
        return is_8_bits and is_token and is_symmetric and is_dynamic

    def _is_fp8_w8a8(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> bool:
        # Confirm weights and activations quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm weight scheme is supported.
        is_floating_point = (weight_quant.type == QuantizationType.FLOAT
                             and input_quant.type == QuantizationType.FLOAT)
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
        ])
        if not (is_floating_point and is_symmetric_weight and is_static_weight
                and is_per_tensor_or_channel_weight):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.dynamic:
            return True

        # Confirm activation scheme is supported.
        is_symmetric_activation = input_quant.symmetric
        is_per_tensor_activation = (
            input_quant.strategy == QuantizationStrategy.TENSOR)
        return is_symmetric_activation and is_per_tensor_activation

    def _is_fp8_w8a16(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
        # Confirm weights quantized.
        if weight_quant is None:
            return False

        # Confirm we have floating points.
        if weight_quant.type != QuantizationType.FLOAT:
            return False

        # Confirm weight scheme is supported.
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
        ])
        if not (is_symmetric_weight and is_static_weight
                and is_per_tensor_or_channel_weight):
            return False

        # All conditions satisfied.
        return True

    def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
        input_quant_none = input_quant is None
        is_symmetric = weight_quant.symmetric
        is_channel_group = (
            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
            or weight_quant.strategy == QuantizationStrategy.GROUP.value)
        is_static = not weight_quant.dynamic
        return (is_channel_group and input_quant_none and is_symmetric
                and is_static)

    def _get_scheme_from_parts(
            self, weight_quant: BaseModel,
            input_quant: BaseModel) -> "CompressedTensorsScheme":

        # Detect if mixed precision.
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if (self.quant_format == CompressionFormat.marlin_24.value
                    and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size)
            if (self.quant_format == CompressionFormat.pack_quantized.value
                    and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
                return CompressedTensorsWNA16(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    group_size=weight_quant.group_size,
                    actorder=weight_quant.actorder)

        # Detect if activation quantization.
        # TODO @dsikka: clean-up conditions
        if is_activation_quantization_format(self.quant_format):
            if self._is_fp8_w8a8(weight_quant, input_quant):
                is_fp8_w8a8_supported = self._check_scheme_supported(
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
                if is_fp8_w8a8_supported:
                    return CompressedTensorsW8A8Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=(input_quant
                                                and not input_quant.dynamic))
                else:
                    return CompressedTensorsW8A16Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=(input_quant
                                                and not input_quant.dynamic))

            if self._is_fp8_w8a16(weight_quant, input_quant):
                return CompressedTensorsW8A16Fp8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=(input_quant
                                            and not input_quant.dynamic))

            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=True)

            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=False)

        raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")

    def get_scheme(
            self,
            layer: torch.nn.Module,
            layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
        """
        compressed-tensors supports non-uniform quantization in the following
        way:

        ignore: list of layer names or nn.Module names to be ignored.

        targets of config_groups: there can be N config_groups, each with its
        own quantization scheme. Each config_group has a list of targets,
        which can be a full layer name, a regex for a layer name, or an
        nn.Module name.

        We first check whether a layer is in the ignore group and use an
        unquantized (i.e. fp16/bf16) scheme for that layer. We then detect
        whether the layer_name matches any target and use the quantization
        scheme corresponding to the matched target to select the
        CompressedTensorsScheme used for inference. (See the matching sketch
        after this method.)
        """
        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO: add compressed-tensors as dep
        # so we do not have to re-write these functions
        # need to make accelerate optional in ct to do this
        matched_target = find_matched_target(
            layer_name=layer_name,
            module=layer,
            targets=self.target_scheme_map.keys())

        # Find the quant_scheme
        scheme_dict = self.target_scheme_map[matched_target]
        scheme = self._get_scheme_from_parts(
            weight_quant=scheme_dict["weights"],
            input_quant=scheme_dict["input_activations"])

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme
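
    # Matching sketch (illustrative; layer and target names are hypothetical):
    # given target_scheme_map keys such as ["re:.*q_proj", "Linear"], a layer
    # named "model.layers.0.self_attn.q_proj" would match the regex target,
    # while a generic nn.Linear layer with no more specific match would fall
    # back to the "Linear" module-name target. The matched target's weights /
    # input_activations QuantizationArgs are then handed to
    # _get_scheme_from_parts above to pick the concrete scheme.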


class CompressedTensorsLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the layer
        input. See LinearMethodBase for param details.
        """
        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return scheme.apply_weights(layer, x, bias=bias)


class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from compressed-tensors
    checkpoints.
    """

    def __init__(self, quant_config: CompressedTensorsConfig):
        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
        """
        Validator for the kv cache scheme. Useful for controlling which kv
        cache quantization schemes are supported in Aphrodite.

        :param kv_cache_scheme: the compressed-tensors kv cache scheme
        """
        if kv_cache_scheme is None:
            return

        type_ = kv_cache_scheme.get("type")
        num_bits = kv_cache_scheme.get("num_bits")

        if type_ != "float" and num_bits != 8:
            raise NotImplementedError(
                "Currently supported kv cache quantization is "
                "num_bits=8, type=float, however "
                f"received num_bits={num_bits}, type={type_}")

        strategy = kv_cache_scheme.get("strategy")
        if strategy != "tensor":
            raise NotImplementedError(
                "Only support per-tensor scaling factor "
                "for compressed-tensors KV cache. "
                f"Expected strategy: tensor, found strategy: {strategy}")

        is_symmetric = kv_cache_scheme.get("symmetric")
        if not is_symmetric:
            raise NotImplementedError(
                "Only support symmetric scaling factor "
                "for compressed-tensors KV cache. "
                f"However found symmetric: {is_symmetric}")