import re
from enum import Enum
from typing import Any, Dict, Iterable, Optional, Union

from pydantic import BaseModel, Field, field_validator
from torch.nn import Module

from aphrodite.quantization.utils.quant_utils import FUSED_LAYER_NAME_MAPPING


class CompressionFormat(Enum):
    dense = "dense"
    sparse_bitmask = "sparse-bitmask"
    naive_quantized = "naive-quantized"
    float_quantized = "float-quantized"
    int_quantized = "int-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"


class QuantizationType(str, Enum):
    """
    Enum storing quantization type options
    """
    INT = "int"
    FLOAT = "float"


class QuantizationStrategy(str, Enum):
    """
    Enum storing quantization strategy options
    """
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"
    TOKEN = "token"


class ActivationOrdering(str, Enum):
    """
    Enum storing strategies for activation ordering

    Group: reorder groups and weight\n
    Weight: only reorder weight, not groups. Slightly lower latency and
    accuracy compared to group actorder\n
    """
    GROUP = "group"
    WEIGHT = "weight"


class QuantizationArgs(BaseModel):
    """
    User-facing arguments used to define a quantization config
    for weights or activations

    :param num_bits: quantization bit depth
    :param type: dtype to quantize to, either int or float
    :param symmetric: whether or not quantization scale is symmetric
    :param strategy: string determining the scope of scale/zero-point to apply
    :param group_size: group length to use for the group strategy
    :param block_structure: 2d block structure to use for the block
        strategy, must be of the format "2x4", "8x16", etc.
    :param dynamic: set True to perform dynamic quantization -
        values will not be calibrated during calibration phase,
        instead during inference new quantization ranges will be
        observed with every sample. Defaults to False for static
        quantization. Note that enabling dynamic quantization
        will change the default observer to a memoryless one
    :param actorder: whether to apply group quantization in decreasing order of
        activation. Defaults to None for arbitrary ordering
    """
    num_bits: int = 8
    type: QuantizationType = QuantizationType.INT
    symmetric: bool = True
    group_size: Optional[int] = None
    strategy: Optional[QuantizationStrategy] = None
    block_structure: Optional[str] = None
    dynamic: bool = False
    actorder: Union[ActivationOrdering, bool, None] = None
    observer: str = Field(
        default="minmax",
        description=("The class to use to compute the quantization params - "
                     "scale and zero-point"),
    )
    observer_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description=
        ("optional dict of kwargs to be passed directly to torch quantization "
         "Observers constructor excluding quantization range or symmetry"),
    )

    @field_validator("actorder", mode="before")
    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
        if isinstance(value, bool):
            return ActivationOrdering.GROUP if value else None

        if isinstance(value, str):
            return ActivationOrdering(value.lower())

        return value
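
# A minimal usage sketch for QuantizationArgs (hypothetical values, not part of
# this module). Pydantic coerces the strings into the enums above, and the
# "actorder" validator maps a bare True onto ActivationOrdering.GROUP:
#
#   >>> args = QuantizationArgs(num_bits=4, type="int", strategy="group",
#   ...                         group_size=128, actorder=True)
#   >>> args.strategy is QuantizationStrategy.GROUP
#   True
#   >>> args.actorder is ActivationOrdering.GROUP
#   True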


def is_activation_quantization_format(format: str) -> bool:
    _ACTIVATION_QUANTIZATION_FORMATS = [
        CompressionFormat.naive_quantized.value,
        CompressionFormat.int_quantized.value,
        CompressionFormat.float_quantized.value
    ]
    return format in _ACTIVATION_QUANTIZATION_FORMATS
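
# Usage sketch (purely illustrative; format strings come from CompressionFormat):
#
#   >>> is_activation_quantization_format(CompressionFormat.int_quantized.value)
#   True
#   >>> is_activation_quantization_format(CompressionFormat.dense.value)
#   False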


def should_ignore_layer(layer_name: Optional[str],
                        ignore: Iterable[str]) -> bool:
    if layer_name is None:
        return False

    # layer_name = model.layers.0.self_attn.qkv_proj
    # proj_name = qkv_proj
    proj_name = layer_name.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]

        # Convert fused_name --> [shard_names]
        shard_names = [
            layer_name.replace(proj_name, shard_proj_name)
            for shard_proj_name in shard_proj_names
        ]

        # Layer should be ignored if shards are ignored.
        should_ignore_layer = None
        for shard_name in shard_names:
            should_ignore_shard = check_equal_or_regex_match(
                layer_name=shard_name, targets=ignore)

            # If shard_idx=0, set layer ignore to match shard.
            if should_ignore_layer is None:
                should_ignore_layer = should_ignore_shard
            # If shard_idx=1+, confirm scheme matches prior shards.
            elif should_ignore_shard != should_ignore_layer:
  120. raise ValueError(f"Found a different quantization schemes for "
  121. f"{shard_proj_names} in {layer_name}. "
  122. "Aphrodite requires all to use the same "
  123. "scheme.")

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    else:
        should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
                                                         targets=ignore)

    assert should_ignore_layer is not None
    return should_ignore_layer
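
# Usage sketch (hypothetical layer names and ignore patterns; assumes
# FUSED_LAYER_NAME_MAPPING maps "qkv_proj" to ["q_proj", "k_proj", "v_proj"]):
#
#   >>> should_ignore_layer("model.layers.0.self_attn.qkv_proj",
#   ...                     ignore=["re:.*self_attn.*"])
#   True
#   >>> should_ignore_layer("model.layers.0.mlp.down_proj", ignore=["lm_head"])
#   False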


def check_equal_or_regex_match(layer_name: str,
                               targets: Iterable[str]) -> bool:
    """
    Checks whether a layer_name is exactly equal to, or (if the target
    starts with 're:') a regex match for, any target in the list.
    """
    for target in targets:
        if _is_equal_or_regex_match(layer_name, target):
            return True
    return False
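
# Usage sketch (hypothetical names): exact targets and "re:"-prefixed regex
# targets are both accepted.
#
#   >>> check_equal_or_regex_match("model.layers.0.mlp.down_proj",
#   ...                            targets=["re:.*down_proj$"])
#   True
#   >>> check_equal_or_regex_match("lm_head", targets=["lm_head"])
#   True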


def find_matched_target(layer_name: Optional[str], module: Module,
                        targets: Iterable[str]) -> str:
    """
    Helper function to look up which "target" in the compressed-tensors
    config a layer corresponds to.

    Recall that a compressed-tensors config has a concept of
    config_groups, where each layer can be quantized with a different
    scheme.

    targets in each config_group will be a list of either layer names
    (or regexes corresponding to layer names) or names of torch Modules.

    First, we try to match the layer_name with a target
    Second, we try to match the module's name with a target

    :param layer_name: layer name
    :param module: torch.nn.Module
    :param targets: list of targets to match the layer against
    """
    if layer_name is None:
        layer_name = ""

    matched_target = (_find_first_match(layer_name, targets)
                      or _find_first_match(module.__class__.__name__, targets,
                                           True))

    if matched_target is None:
        raise ValueError(f"Unable to find matching target for {module} in the "
                         "compressed-tensors config.")

    return matched_target
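
# Usage sketch (hypothetical layer name and targets; uses a plain
# torch.nn.Linear as the module). A module-class target such as "Linear" is
# matched via the module's class name when no layer-name target applies:
#
#   >>> import torch
#   >>> find_matched_target("model.layers.0.mlp.down_proj",
#   ...                     torch.nn.Linear(8, 8), targets=["Linear"])
#   'Linear'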


def _find_first_match(value: str,
                      targets: Iterable[str],
                      check_contains: bool = False) -> Optional[str]:
    """
    Returns the first element of targets that matches value either
    exactly or as a regex after 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.

    :param value: string to compare the list of targets against
    :param targets: list of targets to match the layer against
    :param check_contains: whether or not to do a substring match
    """

    for target in targets:
        if _is_equal_or_regex_match(value,
                                    target,
                                    check_contains=check_contains):
            return target
    return None


def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """
    Check whether the param name matches the format for k/v cache scales
    in compressed-tensors. If this is the case, return its equivalent
    param name expected by Aphrodite.

    :param name: param name
    :return: matching param name for KV cache scale in Aphrodite
    """
    if name.endswith(".output_scale") and ".k_proj" in name:
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".output_scale") and ".v_proj" in name:
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    # If no matches, return None
    return None
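
# Usage sketch (hypothetical checkpoint parameter names):
#
#   >>> get_compressed_tensors_cache_scale(
#   ...     "model.layers.0.self_attn.k_proj.output_scale")
#   'model.layers.0.self_attn.attn.k_scale'
#   >>> get_compressed_tensors_cache_scale(
#   ...     "model.layers.0.mlp.down_proj.weight") is None
#   True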


def _is_equal_or_regex_match(value: str,
                             target: str,
                             check_contains: bool = False) -> bool:
    """
    Checks whether a value is exactly equal or a regex match for target
    if target starts with 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.
    """

    if target.startswith("re:"):
        pattern = target[3:]
        if re.match(pattern, value):
            return True
    elif check_contains:
        if target.lower() in value.lower():
            return True
    elif target == value:
        return True

    return False
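
# Usage sketch: check_contains enables the substring matching used when
# targets name torch Module classes (class name below is illustrative only):
#
#   >>> _is_equal_or_regex_match("QKVParallelLinear", "Linear",
#   ...                          check_contains=True)
#   True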