awq_marlin.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn.parameter import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
  7. set_weight_attrs)
  8. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. from aphrodite.quantization.utils.marlin_utils import (
  11. apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
  12. marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
  13. replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
  14. from aphrodite.scalar_type import scalar_types
  15. class AWQMarlinConfig(QuantizationConfig):
  16. """Config class for AWQ Marlin"""
  17. # num_bits -> type
  18. TYPE_MAP = {
  19. 4: scalar_types.uint4,
  20. 8: scalar_types.uint8,
  21. }
  22. def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
  23. lm_head_quantized: bool) -> None:
  24. self.pack_factor = 32 // weight_bits # packed into int32
  25. self.group_size = group_size
  26. self.has_zp = has_zp
  27. self.lm_head_quantized = lm_head_quantized
  28. if weight_bits not in self.TYPE_MAP:
  29. raise ValueError(f"Unsupported num_bits = {weight_bits}. "
  30. f"Supported num_bits = {self.TYPE_MAP.keys()}")
  31. self.quant_type = self.TYPE_MAP[weight_bits]
  32. verify_marlin_supported(self.quant_type,
  33. group_size=self.group_size,
  34. has_zp=self.has_zp)
  35. def __repr__(self) -> str:
  36. return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
  37. f"group_size={self.group_size}, "
  38. f"has_zp={self.has_zp}, "
  39. f"lm_head_quantized={self.lm_head_quantized})")
  40. @classmethod
  41. def get_name(cls) -> str:
  42. return "awq_marlin"
  43. @classmethod
  44. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  45. return [torch.half, torch.bfloat16]
  46. @classmethod
  47. def get_min_capability(cls) -> int:
  48. return 80
  49. @classmethod
  50. def get_config_filenames(cls) -> List[str]:
  51. return ["quantize_config.json"]
  52. @classmethod
  53. def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
  54. weight_bits = cls.get_from_keys(config, ["bits"])
  55. group_size = cls.get_from_keys(config, ["group_size"])
  56. has_zp = cls.get_from_keys(config, ["zero_point"])
  57. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  58. default=False)
  59. return cls(weight_bits, group_size, has_zp, lm_head_quantized)
  60. @classmethod
  61. def override_quantization_method(cls, hf_quant_cfg,
  62. user_quant) -> Optional[str]:
  63. can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
  64. is_valid_user_quant = (user_quant is None or user_quant == "marlin"
  65. or user_quant == "awq_marlin")
  66. if can_convert and is_valid_user_quant:
  67. msg = ("The model is convertible to {} during runtime."
  68. " Using {} kernel.".format(cls.get_name(), cls.get_name()))
  69. logger.info(msg)
  70. return cls.get_name()
  71. if can_convert and user_quant == "awq":
  72. logger.info("Detected that the model can run with awq_marlin"
  73. ", however you specified quantization=awq explicitly,"
  74. " so forcing awq. Use quantization=awq_marlin for"
  75. " faster inference")
  76. return None
  77. def get_quant_method(self, layer: torch.nn.Module,
  78. prefix: str) -> Optional["AWQMarlinLinearMethod"]:
  79. if (isinstance(layer, LinearBase) or
  80. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  81. return AWQMarlinLinearMethod(self)
  82. return None
  83. def get_scaled_act_names(self) -> List[str]:
  84. return []
  85. @classmethod
  86. def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
  87. # Extract data from quant config.
  88. quant_method = quant_config.get("quant_method", "").lower()
  89. num_bits = quant_config.get("bits", None)
  90. group_size = quant_config.get("group_size", None)
  91. has_zp = quant_config.get("zero_point", None)
  92. if quant_method != "awq":
  93. return False
  94. # If we cannot find the info needed in the config, cannot convert.
  95. if (num_bits is None or group_size is None or has_zp is None):
  96. return False
  97. if num_bits not in cls.TYPE_MAP:
  98. return False
  99. return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
  100. group_size=group_size,
  101. has_zp=has_zp,
  102. min_capability=cls.get_min_capability())
  103. class AWQMarlinLinearMethod(LinearMethodBase):
  104. """Linear method for AWQ Marlin.
  105. Args:
  106. quant_config: The AWQ Marlin quantization config.
  107. """
  108. def __init__(self, quant_config: AWQMarlinConfig) -> None:
  109. self.quant_config = quant_config
  110. def create_weights(
  111. self,
  112. layer: torch.nn.Module,
  113. input_size_per_partition: int,
  114. output_partition_sizes: List[int],
  115. input_size: int,
  116. output_size: int,
  117. params_dtype: torch.dtype,
  118. **extra_weight_attrs,
  119. ) -> None:
  120. del output_size
  121. output_size_per_partition = sum(output_partition_sizes)
  122. # Normalize group_size
  123. if self.quant_config.group_size != -1:
  124. group_size = self.quant_config.group_size
  125. else:
  126. group_size = input_size
  127. verify_marlin_supports_shape(
  128. output_size_per_partition=output_size_per_partition,
  129. input_size_per_partition=input_size_per_partition,
  130. input_size=input_size,
  131. group_size=group_size)
  132. qweight = Parameter(
  133. torch.empty(
  134. input_size_per_partition,
  135. output_size_per_partition // self.quant_config.pack_factor,
  136. dtype=torch.int32,
  137. ),
  138. requires_grad=False,
  139. )
  140. set_weight_attrs(
  141. qweight, {
  142. "input_dim": 0,
  143. "output_dim": 1,
  144. "packed_dim": 1,
  145. "pack_factor": self.quant_config.pack_factor,
  146. })
  147. num_groups = input_size_per_partition // group_size
  148. qzeros = Parameter(
  149. torch.empty(
  150. num_groups,
  151. output_size_per_partition // self.quant_config.pack_factor,
  152. dtype=torch.int32,
  153. ),
  154. requires_grad=False,
  155. )
  156. set_weight_attrs(
  157. qzeros, {
  158. "input_dim": 0,
  159. "output_dim": 1,
  160. "packed_dim": 1,
  161. "pack_factor": self.quant_config.pack_factor,
  162. })
  163. scales = Parameter(
  164. torch.empty(
  165. num_groups,
  166. output_size_per_partition,
  167. dtype=params_dtype,
  168. ),
  169. requires_grad=False,
  170. )
  171. set_weight_attrs(scales, {
  172. "input_dim": 0,
  173. "output_dim": 1,
  174. })
  175. layer.register_parameter("qweight", qweight)
  176. set_weight_attrs(qweight, extra_weight_attrs)
  177. layer.register_parameter("qzeros", qzeros)
  178. set_weight_attrs(qzeros, extra_weight_attrs)
  179. layer.register_parameter("scales", scales)
  180. set_weight_attrs(scales, extra_weight_attrs)
  181. layer.input_size_per_partition = input_size_per_partition
  182. layer.output_size_per_partition = output_size_per_partition
  183. layer.num_groups = num_groups
  184. # TODO: Update this docs
  185. # Checkpoints are serialized in AutoAWQ format, which is different from the
  186. # marlin format. This function is called after the weights are loaded.
  187. # Here, we handle the repacking
  188. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  189. device = layer.qweight.device
  190. # Allocate marlin workspace
  191. layer.workspace = marlin_make_workspace(
  192. layer.output_size_per_partition, device)
  193. # Repack weights from AWQ format to marlin format.
  194. marlin_qweight = ops.awq_marlin_repack(
  195. layer.qweight,
  196. size_k=layer.input_size_per_partition,
  197. size_n=layer.output_size_per_partition,
  198. num_bits=self.quant_config.quant_type.size_bits)
  199. replace_tensor(layer, "qweight", marlin_qweight)
  200. # Permute scales from AWQ format to marlin format.
  201. marlin_scales = marlin_permute_scales(
  202. layer.scales,
  203. size_k=layer.input_size_per_partition,
  204. size_n=layer.output_size_per_partition,
  205. group_size=self.quant_config.group_size)
  206. replace_tensor(layer, "scales", marlin_scales)
  207. # Permute zero-points from AWQ format to marlin format.
  208. marlin_zp = awq_to_marlin_zero_points(
  209. layer.qzeros,
  210. size_k=layer.num_groups,
  211. size_n=layer.output_size_per_partition,
  212. num_bits=self.quant_config.quant_type.size_bits)
  213. replace_tensor(layer, "qzeros", marlin_zp)
  214. # Not-used
  215. layer.g_idx = marlin_make_empty_g_idx(device)
  216. layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
  217. def apply(
  218. self,
  219. layer: torch.nn.Module,
  220. x: torch.Tensor,
  221. bias: Optional[torch.Tensor] = None,
  222. ) -> torch.Tensor:
  223. return apply_awq_marlin_linear(
  224. input=x,
  225. weight=layer.qweight,
  226. weight_scale=layer.scales,
  227. weight_zp=layer.qzeros,
  228. g_idx=layer.g_idx,
  229. g_idx_sort_indices=layer.g_idx_sort_indices,
  230. workspace=layer.workspace,
  231. quant_type=self.quant_config.quant_type,
  232. output_size_per_partition=layer.output_size_per_partition,
  233. input_size_per_partition=layer.input_size_per_partition,
  234. bias=bias)