# compressed_tensors.py (6.4 KB)
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from pydantic import BaseModel
  4. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  5. from aphrodite.quantization.base_config import QuantizationConfig # noqa: E501
  6. from aphrodite.quantization.compressed_tensors.schemes import (
  7. CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
  8. CompressedTensorsW8A8StaticTensor)
  9. from aphrodite.quantization.compressed_tensors.utils import (
  10. QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
  class CompressedTensorsConfig(QuantizationConfig):
      """Quantization config for the "compressed_tensors" method.

      Stores, per target, the parsed weight and input-activation
      quantization arguments, and maps a layer to the matching
      CompressedTensorsScheme.
      """

      def __init__(self, layer_quant_details: Dict[str, Any],
                   ignore: List[str]):
          """Store the per-target quantization details and the ignore list.

          Args:
              layer_quant_details: Maps a target to a dict holding its
                  "weight" and "input" quantization arguments.
              ignore: The "ignore" entry of the config. It is stored as-is
                  and not consulted anywhere in this class.
          """
          self.ignore = ignore
          self.layer_quant_details = layer_quant_details

      def get_linear_method(self) -> "CompressedTensorsLinearMethod":
          """Return a linear method bound to this config."""
          return CompressedTensorsLinearMethod(self)

      def get_scaled_act_names(self) -> List[str]:
          """Return the scaled activation names (always empty here)."""
          return []

      def get_supported_act_dtypes(self) -> List[torch.dtype]:
          """Return the supported activation dtypes (float16 only)."""
          return [torch.float16]

      # TODO: 60 is a placeholder; the real minimum capability still needs
      # to be figured out.
      def get_min_capability(self) -> int:
          """Return the minimum capability value for this method."""
          return 60

      def get_name(self) -> str:
          """Return the name of this quantization method."""
          return "compressed_tensors"

      def get_quant_method(
              self, layer: torch.nn.Module
      ) -> Optional["CompressedTensorsLinearMethod"]:
          """Return the linear method for LinearBase layers, else None."""
          if isinstance(layer, LinearBase):
              return CompressedTensorsLinearMethod(self)
          return None

      @classmethod
      def from_config(
              cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
          """Build a config from a parsed quantization config dict.

          Args:
              config: Must contain "config_groups", a mapping whose values
                  each provide "targets", "weights" and
                  "input_activations". May contain "ignore".

          Returns:
              A config with one entry per target found in any group.
          """
          layer_quant_details: Dict[str, Any] = {}
          # A missing or null "ignore" entry becomes an empty list so the
          # stored value always honours its List[str] annotation.
          ignore: List[str] = config.get("ignore") or []
          # Only the group contents are needed; group names are not used.
          for quant_config in config["config_groups"].values():
              for target in quant_config.get("targets"):
                  layer_quant_details[target] = {
                      "weight":
                      QuantizationArgs.parse_obj(
                          quant_config.get("weights")),
                      "input":
                      QuantizationArgs.parse_obj(
                          quant_config.get("input_activations")),
                  }
          return cls(layer_quant_details=layer_quant_details, ignore=ignore)

      @classmethod
      def get_config_filenames(cls) -> List[str]:
          """Return the config filenames (always empty here)."""
          return []

      def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
          """Check for 8-bit, symmetric, static, per-tensor quantization.

          All four properties must hold for both the weights and the
          input activations.
          """
          is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
          is_tensor = (weight_quant.strategy == input_quant.strategy ==
                       QuantizationStrategy.TENSOR.value)
          is_symmetric = weight_quant.symmetric and input_quant.symmetric
          is_static = not weight_quant.dynamic and not input_quant.dynamic
          return is_8_bits and is_tensor and is_symmetric and is_static

      def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                                 input_quant: BaseModel) -> bool:
          """Check for static per-tensor weights with dynamic per-token inputs.

          Both sides must also be 8-bit and symmetric.
          """
          is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
          is_token_tensor = (
              weight_quant.strategy == QuantizationStrategy.TENSOR.value
          ) and (input_quant.strategy == QuantizationStrategy.TOKEN.value)
          is_symmetric = weight_quant.symmetric and input_quant.symmetric
          # Weights must be static while the inputs are dynamic.
          is_dynamic = not weight_quant.dynamic and input_quant.dynamic
          return (is_8_bits and is_token_tensor and is_symmetric
                  and is_dynamic)

      def _get_schema(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> "CompressedTensorsScheme":
          """Pick the scheme matching the weight/input quantization args.

          Raises:
              NotImplementedError: If neither supported W8A8 variant
                  matches.
          """
          if self._is_static_tensor_w8a8(weight_quant, input_quant):
              return CompressedTensorsW8A8StaticTensor()

          if self._is_dynamic_token_w8a8(weight_quant, input_quant):
              return CompressedTensorsW8A8DynamicToken()

          raise NotImplementedError("Scheme not supported.")

      def get_scheme(
              self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
          """Return the scheme to use for ``layer``.

          The layer is matched against the configured targets with
          find_first_name_or_class_match, using an empty name and
          check_contains=True.

          Raises:
              ValueError: If no target matches the layer, or the matched
                  target has no stored quantization details.
          """
          layer_type_name = find_first_name_or_class_match(
              name="",
              module=layer,
              targets=self.layer_quant_details.keys(),
              check_contains=True)

          if layer_type_name is None:
              raise ValueError(
                  f"Could not find matching target for layer {layer}")

          layer_quant_details: Dict[str, Any] = (
              self.layer_quant_details.get(layer_type_name))
          if layer_quant_details is None:
              raise ValueError(
                  f"Could not find quantization details for {layer}.")

          return self._get_schema(
              weight_quant=layer_quant_details["weight"],
              input_quant=layer_quant_details["input"])
  class CompressedTensorsLinearMethod(LinearMethodBase):
      """Linear method that delegates to a layer's CompressedTensorsScheme."""

      def __init__(self, quantization_config: CompressedTensorsConfig):
          """Keep the config used to look up each layer's scheme."""
          self.quantization_config = quantization_config

      def create_weights(self, layer: torch.nn.Module,
                         input_size_per_partition: int,
                         output_partition_sizes: List[int], input_size: int,
                         output_size: int, params_dtype: torch.dtype,
                         **extra_weight_attrs):
          """
          Use the CompressedTensorsScheme associated with each layer to create
          the necessary parameters for the layer. See LinearMethodBase for
          param details.

          The scheme is looked up from the quantization config and stored on
          the layer as ``layer.scheme`` so that ``apply`` can find it.
          ``input_size`` is accepted but not forwarded to the scheme. Of the
          extra attributes, only "weight_loader" is forwarded (None when
          absent).
          """
          weight_loader = extra_weight_attrs.get("weight_loader")

          scheme = self.quantization_config.get_scheme(layer=layer)
          scheme.create_weights(
              layer=layer,
              input_size_per_partition=input_size_per_partition,
              output_partition_sizes=output_partition_sizes,
              output_size=output_size,
              params_dtype=params_dtype,
              weight_loader=weight_loader)

          layer.scheme = scheme

      def apply(self,
                layer: torch.nn.Module,
                x: torch.Tensor,
                bias: Optional[torch.Tensor] = None):
          """
          Use the output of create_weights and the CompressedTensorsScheme
          associated with the layer to apply the forward pass with the
          layer input. See LinearMethodBase for param details.

          Raises:
              ValueError: If ``bias`` is given, or if the layer has no
                  scheme attached.
          """
          if bias is not None:
              raise ValueError("bias is not supported for this linear method")

          # Plain attribute access would raise AttributeError on a layer
          # that never had a scheme attached, bypassing the check below.
          scheme = getattr(layer, "scheme", None)
          if scheme is None:
              raise ValueError("A scheme must be defined for each layer")

          return scheme.apply_weights(layer, x)