awq_marlin.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from aphrodite import _custom_ops as ops
  5. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  6. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  7. from aphrodite.modeling.parameter import (GroupQuantScaleParameter,
  8. PackedAphroditeParameter)
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. from aphrodite.quantization.utils import replace_parameter
  11. from aphrodite.quantization.utils.marlin_utils import (
  12. apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
  13. marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
  14. verify_marlin_supported, verify_marlin_supports_shape)
  15. from aphrodite.scalar_type import scalar_types
  16. class AWQMarlinConfig(QuantizationConfig):
  17. """Config class for AWQ Marlin"""
  18. # num_bits -> type
  19. TYPE_MAP = {
  20. 4: scalar_types.uint4,
  21. 8: scalar_types.uint8,
  22. }
  23. def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
  24. lm_head_quantized: bool) -> None:
  25. self.pack_factor = 32 // weight_bits # packed into int32
  26. self.group_size = group_size
  27. self.has_zp = has_zp
  28. self.lm_head_quantized = lm_head_quantized
  29. if weight_bits not in self.TYPE_MAP:
  30. raise ValueError(f"Unsupported num_bits = {weight_bits}. "
  31. f"Supported num_bits = {self.TYPE_MAP.keys()}")
  32. self.quant_type = self.TYPE_MAP[weight_bits]
  33. verify_marlin_supported(self.quant_type,
  34. group_size=self.group_size,
  35. has_zp=self.has_zp)
  36. def __repr__(self) -> str:
  37. return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
  38. f"group_size={self.group_size}, "
  39. f"has_zp={self.has_zp}, "
  40. f"lm_head_quantized={self.lm_head_quantized})")
  41. @classmethod
  42. def get_name(cls) -> str:
  43. return "awq_marlin"
  44. @classmethod
  45. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  46. return [torch.half, torch.bfloat16]
  47. @classmethod
  48. def get_min_capability(cls) -> int:
  49. return 80
  50. @classmethod
  51. def get_config_filenames(cls) -> List[str]:
  52. return ["quantize_config.json"]
  53. @classmethod
  54. def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
  55. weight_bits = cls.get_from_keys(config, ["bits"])
  56. group_size = cls.get_from_keys(config, ["group_size"])
  57. has_zp = cls.get_from_keys(config, ["zero_point"])
  58. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  59. default=False)
  60. return cls(weight_bits, group_size, has_zp, lm_head_quantized)
  61. @classmethod
  62. def override_quantization_method(cls, hf_quant_cfg,
  63. user_quant) -> Optional[str]:
  64. can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
  65. is_valid_user_quant = (user_quant is None or user_quant == "marlin"
  66. or user_quant == "awq_marlin")
  67. if can_convert and is_valid_user_quant:
  68. msg = ("The model is convertible to {} during runtime."
  69. " Using {} kernel.".format(cls.get_name(), cls.get_name()))
  70. logger.info(msg)
  71. return cls.get_name()
  72. if can_convert and user_quant == "awq":
  73. logger.info("Detected that the model can run with awq_marlin"
  74. ", however you specified quantization=awq explicitly,"
  75. " so forcing awq. Use quantization=awq_marlin for"
  76. " faster inference")
  77. return None
  78. def get_quant_method(self, layer: torch.nn.Module,
  79. prefix: str) -> Optional["AWQMarlinLinearMethod"]:
  80. if (isinstance(layer, LinearBase) or
  81. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  82. return AWQMarlinLinearMethod(self)
  83. return None
  84. def get_scaled_act_names(self) -> List[str]:
  85. return []
  86. @classmethod
  87. def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
  88. # Extract data from quant config.
  89. quant_method = quant_config.get("quant_method", "").lower()
  90. num_bits = quant_config.get("bits", None)
  91. group_size = quant_config.get("group_size", None)
  92. has_zp = quant_config.get("zero_point", None)
  93. if quant_method != "awq":
  94. return False
  95. # If we cannot find the info needed in the config, cannot convert.
  96. if (num_bits is None or group_size is None or has_zp is None):
  97. return False
  98. if num_bits not in cls.TYPE_MAP:
  99. return False
  100. return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
  101. group_size=group_size,
  102. has_zp=has_zp)
  103. class AWQMarlinLinearMethod(LinearMethodBase):
  104. """Linear method for AWQ Marlin.
  105. Args:
  106. quant_config: The AWQ Marlin quantization config.
  107. """
  108. def __init__(self, quant_config: AWQMarlinConfig) -> None:
  109. self.quant_config = quant_config
  110. def create_weights(
  111. self,
  112. layer: torch.nn.Module,
  113. input_size_per_partition: int,
  114. output_partition_sizes: List[int],
  115. input_size: int,
  116. output_size: int,
  117. params_dtype: torch.dtype,
  118. **extra_weight_attrs,
  119. ) -> None:
  120. del output_size
  121. output_size_per_partition = sum(output_partition_sizes)
  122. weight_loader = extra_weight_attrs.get("weight_loader")
  123. # Normalize group_size
  124. if self.quant_config.group_size != -1:
  125. group_size = self.quant_config.group_size
  126. else:
  127. group_size = input_size
  128. verify_marlin_supports_shape(
  129. output_size_per_partition=output_size_per_partition,
  130. input_size_per_partition=input_size_per_partition,
  131. input_size=input_size,
  132. group_size=group_size)
  133. qweight = PackedAphroditeParameter(
  134. data=torch.empty(
  135. input_size_per_partition,
  136. output_size_per_partition // self.quant_config.pack_factor,
  137. dtype=torch.int32,
  138. ),
  139. input_dim=0,
  140. output_dim=1,
  141. packed_dim=1,
  142. packed_factor=self.quant_config.pack_factor,
  143. weight_loader=weight_loader)
  144. num_groups = input_size_per_partition // group_size
  145. qzeros = PackedAphroditeParameter(
  146. data=torch.empty(
  147. num_groups,
  148. output_size_per_partition // self.quant_config.pack_factor,
  149. dtype=torch.int32,
  150. ),
  151. input_dim=0,
  152. output_dim=1,
  153. packed_dim=1,
  154. packed_factor=self.quant_config.pack_factor,
  155. weight_loader=weight_loader)
  156. scales = GroupQuantScaleParameter(data=torch.empty(
  157. num_groups,
  158. output_size_per_partition,
  159. dtype=params_dtype,
  160. ),
  161. input_dim=0,
  162. output_dim=1,
  163. weight_loader=weight_loader)
  164. layer.register_parameter("qweight", qweight)
  165. layer.register_parameter("qzeros", qzeros)
  166. layer.register_parameter("scales", scales)
  167. layer.input_size_per_partition = input_size_per_partition
  168. layer.output_size_per_partition = output_size_per_partition
  169. layer.num_groups = num_groups
  170. # TODO: Update this docs
  171. # Checkpoints are serialized in AutoAWQ format, which is different from the
  172. # marlin format. This function is called after the weights are loaded.
  173. # Here, we handle the repacking
  174. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  175. device = layer.qweight.device
  176. layer.qweight = torch.nn.Parameter(layer.qweight.data,
  177. requires_grad=False)
  178. layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
  179. requires_grad=False)
  180. layer.scales = torch.nn.Parameter(layer.scales.data,
  181. requires_grad=False)
  182. # Allocate marlin workspace
  183. layer.workspace = marlin_make_workspace(
  184. layer.output_size_per_partition, device)
  185. # Repack weights from AWQ format to marlin format.
  186. marlin_qweight = ops.awq_marlin_repack(
  187. layer.qweight,
  188. size_k=layer.input_size_per_partition,
  189. size_n=layer.output_size_per_partition,
  190. num_bits=self.quant_config.quant_type.size_bits)
  191. replace_parameter(layer, "qweight", marlin_qweight)
  192. # Permute scales from AWQ format to marlin format.
  193. marlin_scales = marlin_permute_scales(
  194. layer.scales,
  195. size_k=layer.input_size_per_partition,
  196. size_n=layer.output_size_per_partition,
  197. group_size=self.quant_config.group_size)
  198. replace_parameter(layer, "scales", marlin_scales)
  199. # Permute zero-points from AWQ format to marlin format.
  200. marlin_zp = awq_to_marlin_zero_points(
  201. layer.qzeros,
  202. size_k=layer.num_groups,
  203. size_n=layer.output_size_per_partition,
  204. num_bits=self.quant_config.quant_type.size_bits)
  205. replace_parameter(layer, "qzeros", marlin_zp)
  206. # Not-used
  207. layer.g_idx = marlin_make_empty_g_idx(device)
  208. layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
  209. def apply(
  210. self,
  211. layer: torch.nn.Module,
  212. x: torch.Tensor,
  213. bias: Optional[torch.Tensor] = None,
  214. ) -> torch.Tensor:
  215. return apply_awq_marlin_linear(
  216. input=x,
  217. weight=layer.qweight,
  218. weight_scale=layer.scales,
  219. weight_zp=layer.qzeros,
  220. g_idx=layer.g_idx,
  221. g_idx_sort_indices=layer.g_idx_sort_indices,
  222. workspace=layer.workspace,
  223. quant_type=self.quant_config.quant_type,
  224. output_size_per_partition=layer.output_size_per_partition,
  225. input_size_per_partition=layer.input_size_per_partition,
  226. bias=bias)