awq_marlin.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn.parameter import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
  7. set_weight_attrs)
  8. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. from aphrodite.quantization.utils.marlin_utils import (
  11. apply_awq_marlin_linear, awq_to_marlin_zero_points,
  12. check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
  13. marlin_permute_scales, replace_tensor, verify_awq_marlin_supported,
  14. verify_marlin_supports_shape)
  15. class AWQMarlinConfig(QuantizationConfig):
  16. """Config class for AWQ Marlin"""
  17. def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
  18. lm_head_quantized: bool) -> None:
  19. self.weight_bits = weight_bits
  20. self.pack_factor = 32 // self.weight_bits # packed into int32
  21. self.group_size = group_size
  22. self.has_zp = has_zp
  23. self.lm_head_quantized = lm_head_quantized
  24. verify_awq_marlin_supported(num_bits=self.weight_bits,
  25. group_size=self.group_size,
  26. has_zp=self.has_zp)
  27. def __repr__(self) -> str:
  28. return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
  29. f"group_size={self.group_size}, "
  30. f"has_zp={self.has_zp}, "
  31. f"lm_head_quantized={self.lm_head_quantized})")
  32. @classmethod
  33. def get_name(cls) -> str:
  34. return "awq_marlin"
  35. @classmethod
  36. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  37. return [torch.half, torch.bfloat16]
  38. @classmethod
  39. def get_min_capability(cls) -> int:
  40. return 80
  41. @classmethod
  42. def get_config_filenames(cls) -> List[str]:
  43. return ["quantize_config.json"]
  44. @classmethod
  45. def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
  46. weight_bits = cls.get_from_keys(config, ["bits"])
  47. group_size = cls.get_from_keys(config, ["group_size"])
  48. has_zp = cls.get_from_keys(config, ["zero_point"])
  49. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  50. default=False)
  51. return cls(weight_bits, group_size, has_zp, lm_head_quantized)
  52. @classmethod
  53. def override_quantization_method(cls, hf_quant_cfg,
  54. user_quant) -> Optional[str]:
  55. can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
  56. is_valid_user_quant = (user_quant is None or user_quant == "marlin")
  57. if can_convert and is_valid_user_quant:
  58. msg = ("The model is convertible to {} during runtime."
  59. " Using {} kernel.".format(cls.get_name(), cls.get_name()))
  60. logger.info(msg)
  61. return cls.get_name()
  62. if can_convert and user_quant == "awq":
  63. logger.info("Detected that the model can run with awq_marlin"
  64. ", however you specified quantization=awq explicitly,"
  65. " so forcing awq. Use quantization=awq_marlin for"
  66. " faster inference")
  67. return None
  68. def get_quant_method(self, layer: torch.nn.Module,
  69. prefix: str) -> Optional["AWQMarlinLinearMethod"]:
  70. if (isinstance(layer, LinearBase) or
  71. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  72. return AWQMarlinLinearMethod(self)
  73. return None
  74. def get_scaled_act_names(self) -> List[str]:
  75. return []
  76. @classmethod
  77. def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
  78. # Extract data from quant config.
  79. quant_method = quant_config.get("quant_method", "").lower()
  80. num_bits = quant_config.get("bits", None)
  81. group_size = quant_config.get("group_size", None)
  82. has_zp = quant_config.get("zero_point", None)
  83. if quant_method != "awq":
  84. return False
  85. # If we cannot find the info needed in the config, cannot convert.
  86. if (num_bits is None or group_size is None or has_zp is None):
  87. return False
  88. return check_awq_marlin_supported(
  89. num_bits=num_bits,
  90. group_size=group_size,
  91. has_zp=has_zp,
  92. min_capability=cls.get_min_capability())
  93. class AWQMarlinLinearMethod(LinearMethodBase):
  94. """Linear method for AWQ Marlin.
  95. Args:
  96. quant_config: The AWQ Marlin quantization config.
  97. """
  98. def __init__(self, quant_config: AWQMarlinConfig) -> None:
  99. self.quant_config = quant_config
  100. def create_weights(
  101. self,
  102. layer: torch.nn.Module,
  103. input_size_per_partition: int,
  104. output_partition_sizes: List[int],
  105. input_size: int,
  106. output_size: int,
  107. params_dtype: torch.dtype,
  108. **extra_weight_attrs,
  109. ) -> None:
  110. del output_size
  111. output_size_per_partition = sum(output_partition_sizes)
  112. # Normalize group_size
  113. if self.quant_config.group_size != -1:
  114. group_size = self.quant_config.group_size
  115. else:
  116. group_size = input_size
  117. verify_marlin_supports_shape(
  118. output_size_per_partition=output_size_per_partition,
  119. input_size_per_partition=input_size_per_partition,
  120. input_size=input_size,
  121. group_size=group_size)
  122. qweight = Parameter(
  123. torch.empty(
  124. input_size_per_partition,
  125. output_size_per_partition // self.quant_config.pack_factor,
  126. dtype=torch.int32,
  127. ),
  128. requires_grad=False,
  129. )
  130. set_weight_attrs(
  131. qweight, {
  132. "input_dim": 0,
  133. "output_dim": 1,
  134. "packed_dim": 1,
  135. "pack_factor": self.quant_config.pack_factor,
  136. })
  137. num_groups = input_size_per_partition // group_size
  138. qzeros = Parameter(
  139. torch.empty(
  140. num_groups,
  141. output_size_per_partition // self.quant_config.pack_factor,
  142. dtype=torch.int32,
  143. ),
  144. requires_grad=False,
  145. )
  146. set_weight_attrs(
  147. qzeros, {
  148. "input_dim": 0,
  149. "output_dim": 1,
  150. "packed_dim": 1,
  151. "pack_factor": self.quant_config.pack_factor,
  152. })
  153. scales = Parameter(
  154. torch.empty(
  155. num_groups,
  156. output_size_per_partition,
  157. dtype=params_dtype,
  158. ),
  159. requires_grad=False,
  160. )
  161. set_weight_attrs(scales, {
  162. "input_dim": 0,
  163. "output_dim": 1,
  164. })
  165. layer.register_parameter("qweight", qweight)
  166. set_weight_attrs(qweight, extra_weight_attrs)
  167. layer.register_parameter("qzeros", qzeros)
  168. set_weight_attrs(qzeros, extra_weight_attrs)
  169. layer.register_parameter("scales", scales)
  170. set_weight_attrs(scales, extra_weight_attrs)
  171. layer.input_size_per_partition = input_size_per_partition
  172. layer.output_size_per_partition = output_size_per_partition
  173. layer.num_groups = num_groups
  174. # TODO: Update this docs
  175. # Checkpoints are serialized in AutoAWQ format, which is different from the
  176. # marlin format. This function is called after the weights are loaded.
  177. # Here, we handle the repacking
  178. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  179. device = layer.qweight.device
  180. # Allocate marlin workspace
  181. layer.workspace = marlin_make_workspace(
  182. layer.output_size_per_partition, device)
  183. # Repack weights from AWQ format to marlin format.
  184. marlin_qweight = ops.awq_marlin_repack(
  185. layer.qweight,
  186. size_k=layer.input_size_per_partition,
  187. size_n=layer.output_size_per_partition,
  188. num_bits=self.quant_config.weight_bits)
  189. replace_tensor(layer, "qweight", marlin_qweight)
  190. # Permute scales from AWQ format to marlin format.
  191. marlin_scales = marlin_permute_scales(
  192. layer.scales,
  193. size_k=layer.input_size_per_partition,
  194. size_n=layer.output_size_per_partition,
  195. group_size=self.quant_config.group_size)
  196. replace_tensor(layer, "scales", marlin_scales)
  197. # Permute zero-points from AWQ format to marlin format.
  198. marlin_zp = awq_to_marlin_zero_points(
  199. layer.qzeros,
  200. size_k=layer.num_groups,
  201. size_n=layer.output_size_per_partition,
  202. num_bits=self.quant_config.weight_bits)
  203. replace_tensor(layer, "qzeros", marlin_zp)
  204. # Not-used
  205. layer.g_idx = marlin_make_empty_g_idx(device)
  206. layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
  207. def apply(
  208. self,
  209. layer: torch.nn.Module,
  210. x: torch.Tensor,
  211. bias: Optional[torch.Tensor] = None,
  212. ) -> torch.Tensor:
  213. return apply_awq_marlin_linear(
  214. input=x,
  215. weight=layer.qweight,
  216. weight_scale=layer.scales,
  217. weight_zp=layer.qzeros,
  218. g_idx=layer.g_idx,
  219. g_idx_sort_indices=layer.g_idx_sort_indices,
  220. workspace=layer.workspace,
  221. num_bits=self.quant_config.weight_bits,
  222. output_size_per_partition=layer.output_size_per_partition,
  223. input_size_per_partition=layer.input_size_per_partition,
  224. bias=bias)