utils.py

import re
from enum import Enum
from typing import Any, Dict, Iterable, Optional

from pydantic import BaseModel, Field
from torch.nn import Module


class CompressionFormat(Enum):
    dense = "dense"
    sparse_bitmask = "sparse-bitmask"
    naive_quantized = "naive-quantized"
    float_quantized = "float-quantized"
    int_quantized = "int-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"


class QuantizationType(str, Enum):
    """
    Enum storing quantization type options
    """
    INT = "int"
    FLOAT = "float"


class QuantizationStrategy(str, Enum):
    """
    Enum storing quantization strategy options
    """
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"
    TOKEN = "token"


class QuantizationArgs(BaseModel):
    """
    User-facing arguments used to define a quantization config
    for weights or activations

    :param num_bits: quantization bit depth
    :param type: dtype to quantize to, either int or float
    :param symmetric: whether or not the quantization scale is symmetric
    :param strategy: string determining the scope of scale/zero-point to apply
    :param group_size: group length to use for the group strategy
    :param block_structure: 2d block structure to use for the block
        strategy, must be of the format "2x4", "8x16", etc.
    :param dynamic: set True to perform dynamic quantization -
        values will not be calibrated during a calibration phase;
        instead, new quantization ranges will be observed with every
        sample at inference time. Defaults to False for static
        quantization. Note that enabling dynamic quantization
        will change the default observer to a memoryless one
    """
    num_bits: int = 8
    type: QuantizationType = QuantizationType.INT
    symmetric: bool = True
    group_size: Optional[int] = None
    strategy: Optional[QuantizationStrategy] = None
    block_structure: Optional[str] = None
    dynamic: bool = False
    observer: str = Field(
        default="minmax",
        description=("The class to use to compute the quantization "
                     "parameters - scale and zero-point"),
    )
    observer_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description=("optional dict of kwargs to be passed directly to the "
                     "torch quantization Observer constructor, excluding "
                     "quantization range or symmetry"),
    )
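
# Illustrative usage (not part of the original module): a 4-bit symmetric,
# grouped weight scheme expressed as QuantizationArgs. The field values below
# are assumptions chosen for demonstration only.
#
#     >>> args = QuantizationArgs(num_bits=4,
#     ...                         type=QuantizationType.INT,
#     ...                         strategy=QuantizationStrategy.GROUP,
#     ...                         group_size=128)
#     >>> args.symmetric, args.dynamic
#     (True, False)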


def is_activation_quantization_format(format: str) -> bool:
    """Check whether the given compression format quantizes activations."""
    _ACTIVATION_QUANTIZATION_FORMATS = [
        CompressionFormat.naive_quantized.value,
        CompressionFormat.int_quantized.value,
        CompressionFormat.float_quantized.value
    ]
    return format in _ACTIVATION_QUANTIZATION_FORMATS
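
# Example (illustrative, not part of the original module):
#
#     >>> is_activation_quantization_format(
#     ...     CompressionFormat.int_quantized.value)
#     True
#     >>> is_activation_quantization_format(CompressionFormat.marlin_24.value)
#     False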


# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}


def should_ignore_layer(layer_name: Optional[str],
                        ignore: Iterable[str]) -> bool:
    """
    Check whether layer_name matches the ignore list, resolving fused
    layer names to their unfused shards first.
    """
    if layer_name is None:
        return False

    # layer_name = model.layers.0.self_attn.qkv_proj
    # proj_name = qkv_proj
    proj_name = layer_name.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in _FUSED_LAYER_NAME_MAPPING:
        shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]

        # Convert fused_name --> [shard_names]
        shard_names = [
            layer_name.replace(proj_name, shard_proj_name)
            for shard_proj_name in shard_proj_names
        ]

        # Layer should be ignored if shards are ignored.
        should_ignore_layer = None
        for shard_name in shard_names:
            should_ignore_shard = check_equal_or_regex_match(
                layer_name=shard_name, targets=ignore)

            # If shard_idx=0, set layer ignore to match shard.
            if should_ignore_layer is None:
                should_ignore_layer = should_ignore_shard

            # If shard_idx=1+, confirm the scheme matches prior shards.
            elif should_ignore_shard != should_ignore_layer:
                raise ValueError(f"Found different quantization schemes for "
                                 f"{shard_proj_names} in {layer_name}. vLLM "
                                 "requires all to use the same scheme.")

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    else:
        should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
                                                         targets=ignore)

    assert should_ignore_layer is not None
    return should_ignore_layer
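
# Example (illustrative, not part of the original module): a fused qkv_proj
# layer is ignored only when every unfused shard matches the ignore list.
# The layer names and ignore target below are assumptions for demonstration.
#
#     >>> should_ignore_layer("model.layers.0.self_attn.qkv_proj",
#     ...                     ignore=[r"re:.*self_attn\..*_proj"])
#     True
#     >>> should_ignore_layer("model.layers.0.mlp.down_proj",
#     ...                     ignore=[r"re:.*self_attn\..*_proj"])
#     False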


def check_equal_or_regex_match(layer_name: str,
                               targets: Iterable[str]) -> bool:
    """
    Checks whether layer_name matches any target in the list, either by
    exact string equality or, for targets prefixed with 're:', by regex match.
    """
    for target in targets:
        if _is_equal_or_regex_match(layer_name, target):
            return True
    return False
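
# Example (illustrative, not part of the original module); the layer names
# and targets below are assumptions for demonstration.
#
#     >>> check_equal_or_regex_match("model.layers.0.self_attn.q_proj",
#     ...                            targets=["re:.*q_proj$"])
#     True
#     >>> check_equal_or_regex_match("lm_head", targets=["lm_head"])
#     True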


def find_matched_target(layer_name: Optional[str], module: Module,
                        targets: Iterable[str]) -> str:
    """
    Helper function to look up which "target" in the compressed-tensors
    config a layer corresponds to.

    Recall that a compressed-tensors config has a concept of
    config_groups, where each layer can be quantized with a different
    scheme.

    targets in each config_group will be a list of either layer names
    (or regexes corresponding to layer names) or names of torch Modules.

    First, we try to match the layer_name with a target.
    Second, we try to match the module's class name with a target.

    :param layer_name: layer name
    :param module: torch.nn.Module
    :param targets: list of targets to match the layer against
    """
    if layer_name is None:
        layer_name = ""

    matched_target = (_find_first_match(layer_name, targets)
                      or _find_first_match(module.__class__.__name__, targets,
                                           True))

    if matched_target is None:
        raise ValueError(f"Unable to find matching target for {module} in the "
                         "compressed-tensors config.")

    return matched_target
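
# Example (illustrative, not part of the original module): a layer whose name
# does not appear in the targets can still match via its module class name.
# The layer name and target below are assumptions for demonstration.
#
#     >>> import torch
#     >>> find_matched_target(layer_name="model.layers.0.mlp.down_proj",
#     ...                     module=torch.nn.Linear(16, 16),
#     ...                     targets=["Linear"])
#     'Linear'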


def _find_first_match(value: str,
                      targets: Iterable[str],
                      check_contains: bool = False) -> Optional[str]:
    """
    Returns the first element of targets that matches value, either
    exactly or as a regex after 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.

    :param value: string to compare the list of targets against
    :param targets: list of targets to match the layer against
    :param check_contains: whether or not to do a substring match
    """
    for target in targets:
        if _is_equal_or_regex_match(value,
                                    target,
                                    check_contains=check_contains):
            return target
    return None


def _is_equal_or_regex_match(value: str,
                             target: str,
                             check_contains: bool = False) -> bool:
    """
    Checks whether a value is exactly equal or a regex match for target
    if target starts with 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.
    """
    if target.startswith("re:"):
        pattern = target[3:]
        if re.match(pattern, value):
            return True
    elif check_contains:
        if target.lower() in value.lower():
            return True
    elif target == value:
        return True

    return False
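
# Example (illustrative, not part of the original module): check_contains
# lets module-class-style targets match by case-insensitive substring.
# The class name below is an assumption for demonstration.
#
#     >>> _is_equal_or_regex_match("QKVParallelLinear", "Linear",
#     ...                          check_contains=True)
#     True
#     >>> _is_equal_or_regex_match("QKVParallelLinear", "Linear")
#     False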