gptq_marlin.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. from typing import Any, Dict, List, Optional
  2. import torch
  3. from loguru import logger
  4. from torch.nn.parameter import Parameter
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.modeling.layers.linear import (LinearBase, LinearMethodBase,
  7. set_weight_attrs)
  8. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  9. from aphrodite.quantization.base_config import QuantizationConfig
  10. from aphrodite.quantization.utils.marlin_utils import (
  11. apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
  12. marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
  13. marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
  14. verify_gptq_marlin_supported, verify_marlin_supports_shape)
  15. class GPTQMarlinConfig(QuantizationConfig):
  16. """Config class for GPTQ Marlin"""
  17. def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
  18. is_sym: bool, lm_head_quantized: bool) -> None:
  19. if desc_act and group_size == -1:
  20. # In this case, act_order == True is the same as act_order == False
  21. # (since we have only one group per output channel)
  22. desc_act = False
  23. self.weight_bits = weight_bits
  24. self.pack_factor = 32 // self.weight_bits # packed into int32
  25. self.group_size = group_size
  26. self.desc_act = desc_act
  27. self.is_sym = is_sym
  28. self.lm_head_quantized = lm_head_quantized
  29. # Verify supported on platform.
  30. verify_gptq_marlin_supported(num_bits=self.weight_bits,
  31. group_size=self.group_size,
  32. is_sym=self.is_sym)
  33. def __repr__(self) -> str:
  34. return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
  35. f"group_size={self.group_size}, "
  36. f"desc_act={self.desc_act}, "
  37. f"lm_head_quantized={self.lm_head_quantized})")
  38. @classmethod
  39. def get_name(cls) -> str:
  40. return "gptq_marlin"
  41. @classmethod
  42. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  43. return [torch.half, torch.bfloat16]
  44. @classmethod
  45. def get_min_capability(cls) -> int:
  46. return 80
  47. @classmethod
  48. def get_config_filenames(cls) -> List[str]:
  49. return ["quantize_config.json"]
  50. @classmethod
  51. def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
  52. weight_bits = cls.get_from_keys(config, ["bits"])
  53. group_size = cls.get_from_keys(config, ["group_size"])
  54. desc_act = cls.get_from_keys(config, ["desc_act"])
  55. is_sym = cls.get_from_keys(config, ["sym"])
  56. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  57. default=False)
  58. return cls(weight_bits, group_size, desc_act, is_sym,
  59. lm_head_quantized)
  60. @classmethod
  61. def override_quantization_method(cls, hf_quant_cfg,
  62. user_quant) -> Optional[str]:
  63. can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
  64. is_valid_user_quant = (user_quant is None or user_quant == "marlin")
  65. if can_convert and is_valid_user_quant:
  66. msg = ("The model is convertible to {} during runtime."
  67. " Using {} kernel.".format(cls.get_name(), cls.get_name()))
  68. logger.info(msg)
  69. return cls.get_name()
  70. if can_convert and user_quant == "gptq":
  71. logger.info("Detected that the model can run with gptq_marlin"
  72. ", however you specified quantization=gptq explicitly,"
  73. " so forcing gptq. Use quantization=gptq_marlin for"
  74. " faster inference")
  75. return None
  76. def get_quant_method(self, layer: torch.nn.Module,
  77. prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
  78. if (isinstance(layer, LinearBase) or
  79. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  80. return GPTQMarlinLinearMethod(self)
  81. return None
  82. def get_scaled_act_names(self) -> List[str]:
  83. return []
  84. @classmethod
  85. def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
  86. # Extract data from quant config.
  87. quant_method = quant_config.get("quant_method", "").lower()
  88. num_bits = quant_config.get("bits", None)
  89. group_size = quant_config.get("group_size", None)
  90. sym = quant_config.get("sym", None)
  91. desc_act = quant_config.get("desc_act", None)
  92. if quant_method != "gptq":
  93. return False
  94. # If we cannot find the info needed in the config, cannot convert.
  95. if (num_bits is None or group_size is None or sym is None
  96. or desc_act is None):
  97. return False
  98. return check_gptq_marlin_supported(
  99. num_bits=num_bits,
  100. group_size=group_size,
  101. is_sym=sym,
  102. min_capability=cls.get_min_capability())
  103. class GPTQMarlinLinearMethod(LinearMethodBase):
  104. """Linear method for GPTQ Marlin.
  105. Args:
  106. quant_config: The GPTQ Marlin quantization config.
  107. """
  108. def __init__(self, quant_config: GPTQMarlinConfig) -> None:
  109. self.quant_config = quant_config
  110. def create_weights(
  111. self,
  112. layer: torch.nn.Module,
  113. input_size_per_partition: int,
  114. output_partition_sizes: List[int],
  115. input_size: int,
  116. output_size: int,
  117. params_dtype: torch.dtype,
  118. **extra_weight_attrs,
  119. ) -> None:
  120. del output_size
  121. output_size_per_partition = sum(output_partition_sizes)
  122. is_row_parallel = input_size != input_size_per_partition
  123. # Normalize group_size
  124. if self.quant_config.group_size != -1:
  125. group_size = self.quant_config.group_size
  126. else:
  127. group_size = input_size
  128. verify_marlin_supports_shape(
  129. output_size_per_partition=output_size_per_partition,
  130. input_size_per_partition=input_size_per_partition,
  131. input_size=input_size,
  132. group_size=group_size)
  133. # Determine sharding
  134. if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
  135. self.quant_config.group_size,
  136. is_row_parallel):
  137. # By setting scale_dim == None, weight_loader will
  138. # repeat the scales on each GPU in TP>1 case.
  139. scales_and_zp_input_dim = None
  140. scales_and_zp_size = input_size // group_size
  141. else:
  142. # By setting scale_dim == 0, weight_loader will
  143. # shard the scales in TP>1 case.
  144. scales_and_zp_input_dim = 0
  145. scales_and_zp_size = input_size_per_partition // group_size
  146. # Quantized weights
  147. qweight = Parameter(
  148. torch.empty(
  149. input_size_per_partition // self.quant_config.pack_factor,
  150. output_size_per_partition,
  151. dtype=torch.int32,
  152. ),
  153. requires_grad=False,
  154. )
  155. set_weight_attrs(
  156. qweight,
  157. {
  158. **extra_weight_attrs,
  159. "input_dim": 0,
  160. "output_dim": 1,
  161. "packed_dim": 0,
  162. "pack_factor": self.quant_config.pack_factor,
  163. },
  164. )
  165. # Activation order
  166. g_idx = Parameter(
  167. torch.empty(
  168. input_size_per_partition,
  169. dtype=torch.int32,
  170. ),
  171. requires_grad=False,
  172. )
  173. # Ignore warning from fused linear layers such as QKVParallelLinear.
  174. set_weight_attrs(
  175. g_idx,
  176. {
  177. **extra_weight_attrs, "input_dim": 0,
  178. "ignore_warning": True
  179. },
  180. )
  181. # Scales
  182. scales = Parameter(
  183. torch.empty(
  184. scales_and_zp_size,
  185. output_size_per_partition,
  186. dtype=params_dtype,
  187. ),
  188. requires_grad=False,
  189. )
  190. set_weight_attrs(
  191. scales,
  192. {
  193. **extra_weight_attrs,
  194. "input_dim": scales_and_zp_input_dim,
  195. "output_dim": 1,
  196. },
  197. )
  198. # Quantized zero-points
  199. qzeros = Parameter(
  200. torch.empty(
  201. scales_and_zp_size,
  202. output_size_per_partition // self.quant_config.pack_factor,
  203. dtype=torch.int32,
  204. device="meta",
  205. ),
  206. requires_grad=False,
  207. )
  208. set_weight_attrs(
  209. qzeros,
  210. {
  211. **extra_weight_attrs,
  212. "input_dim": scales_and_zp_input_dim,
  213. "output_dim": 1,
  214. "packed_dim": 1,
  215. "pack_factor": self.quant_config.pack_factor,
  216. },
  217. )
  218. layer.register_parameter("qweight", qweight)
  219. layer.register_parameter("g_idx", g_idx)
  220. layer.register_parameter("scales", scales)
  221. layer.register_parameter("qzeros", qzeros)
  222. layer.input_size_per_partition = input_size_per_partition
  223. layer.output_size_per_partition = output_size_per_partition
  224. layer.input_size = input_size
  225. layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
  226. is_row_parallel)
  227. # Checkpoints are serialized in AutoGPTQ format, which is different from the
  228. # marlin format. This function is called after the weights are loaded.
  229. # Here, we handle the repacking, including the activation reordering case.
  230. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  231. device = layer.qweight.device
  232. # Allocate marlin workspace
  233. layer.workspace = marlin_make_workspace(
  234. layer.output_size_per_partition, device)
  235. # Handle sorting for activation reordering if needed.
  236. if self.quant_config.desc_act:
  237. g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
  238. layer.g_idx_sort_indices = g_idx_sort_indices
  239. replace_tensor(layer, "g_idx", g_idx)
  240. else:
  241. layer.g_idx = marlin_make_empty_g_idx(device)
  242. layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
  243. # No zero-point
  244. layer.zp = marlin_make_empty_g_idx(device)
  245. # Repack weights from autogptq format to marlin format.
  246. marlin_qweight = ops.gptq_marlin_repack(
  247. layer.qweight,
  248. perm=layer.g_idx_sort_indices,
  249. size_k=layer.input_size_per_partition,
  250. size_n=layer.output_size_per_partition,
  251. num_bits=self.quant_config.weight_bits)
  252. replace_tensor(layer, "qweight", marlin_qweight)
  253. # Permute scales from autogptq format to marlin format.
  254. marlin_scales = marlin_permute_scales(
  255. layer.scales,
  256. size_k=(layer.input_size if self.quant_config.desc_act else
  257. layer.input_size_per_partition),
  258. size_n=layer.output_size_per_partition,
  259. group_size=self.quant_config.group_size)
  260. replace_tensor(layer, "scales", marlin_scales)
  261. def apply(
  262. self,
  263. layer: torch.nn.Module,
  264. x: torch.Tensor,
  265. bias: Optional[torch.Tensor] = None,
  266. ) -> torch.Tensor:
  267. return apply_gptq_marlin_linear(
  268. input=x,
  269. weight=layer.qweight,
  270. weight_scale=layer.scales,
  271. weight_zp=layer.zp,
  272. g_idx=layer.g_idx,
  273. g_idx_sort_indices=layer.g_idx_sort_indices,
  274. workspace=layer.workspace,
  275. num_bits=self.quant_config.weight_bits,
  276. output_size_per_partition=layer.output_size_per_partition,
  277. input_size_per_partition=layer.input_size_per_partition,
  278. is_k_full=layer.is_k_full,
  279. bias=bias)