gptq.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. import enum
  2. from enum import Enum
  3. from typing import Any, Dict, List, Optional
  4. from fractions import Fraction
  5. from contextlib import suppress
  6. import torch
  7. from torch.nn.parameter import Parameter
  8. from aphrodite.modeling.layers.fused_moe import (fused_moe, fused_topk,
  9. moe_align_block_size)
  10. from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
  11. from aphrodite.quantization.base_config import (
  12. QuantizationConfig, )
  13. from aphrodite._C import ops as _C_ops
  14. HAS_QUANTS = False
  15. with suppress(ImportError):
  16. from aphrodite._quant_C import quant_ops as ops
  17. HAS_QUANTS = True
  18. class GPTQConfig(QuantizationConfig):
  19. """Config class for GPTQ.
  20. Reference: https://arxiv.org/abs/2210.17323
  21. """
  22. def __init__(
  23. self,
  24. weight_bits: int,
  25. group_size: int,
  26. desc_act: bool,
  27. ) -> None:
  28. if not HAS_QUANTS:
  29. raise ImportError("Could not find the quantization kernels.")
  30. self.weight_bits = weight_bits
  31. self.group_size = group_size
  32. self.desc_act = desc_act
  33. self.pack_factor = Fraction(32, self.weight_bits)
  34. if self.weight_bits not in [2, 3, 4, 8]:
  35. raise ValueError(
  36. "Currently, only 2/3/4/8-bit weight quantization is supported "
  37. f"for GPTQ, but got {self.weight_bits} bits.")
  38. def __repr__(self) -> str:
  39. return (f"GPTQConfig(weight_bits={self.weight_bits}, "
  40. f"group_size={self.group_size}, "
  41. f"desc_act={self.desc_act})")
  42. @classmethod
  43. def get_name(cls) -> str:
  44. return "gptq"
  45. @classmethod
  46. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  47. return [torch.half]
  48. @classmethod
  49. # Need to figure it out
  50. def get_min_capability(cls) -> int:
  51. return 60
  52. @classmethod
  53. def get_config_filenames(cls) -> List[str]:
  54. return ["quantize_config.json"]
  55. @classmethod
  56. def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
  57. weight_bits = cls.get_from_keys(config, ["bits"])
  58. group_size = cls.get_from_keys(config, ["group_size"])
  59. desc_act = cls.get_from_keys(config, ["desc_act"])
  60. return cls(weight_bits, group_size, desc_act)
  61. def get_linear_method(self) -> "GPTQLinearMethod":
  62. return GPTQLinearMethod(self)
  63. def get_scaled_act_names(self) -> List[str]:
  64. return []
  65. def merge_weight(self) -> bool:
  66. return True
  67. def rope_style(self) -> Optional[bool]:
  68. return None
  69. def quant_vocab(self) -> List[bool]:
  70. return [False, False]
  71. def support_fused_moe(self) -> bool:
  72. return self.weight_bits == 4
  73. class ExllamaState(Enum):
  74. UNUSED = enum.auto()
  75. UNINITIALIZED = enum.auto()
  76. READY = enum.auto()
  77. class GPTQLinearMethod(LinearMethodBase):
  78. """Linear method for GPTQ.
  79. Args:
  80. quant_config: The GPTQ quantization config.
  81. """
  82. def __init__(self, quant_config: GPTQConfig):
  83. self.quant_config = quant_config
  84. def create_weights(
  85. self,
  86. input_size_per_partition: int,
  87. output_partition_sizes: List[int],
  88. input_size: int,
  89. output_size: int,
  90. params_dtype: torch.dtype,
  91. ) -> Dict[str, Any]:
  92. del output_size # Unused.
  93. if input_size_per_partition % self.quant_config.group_size != 0:
  94. raise ValueError(
  95. "The input size is not aligned with the quantized "
  96. "weight shape. This can be caused by too large "
  97. "tensor parallel size.")
  98. output_size_per_partition = sum(output_partition_sizes)
  99. if (output_size_per_partition % self.quant_config.pack_factor.numerator
  100. != 0):
  101. raise ValueError(
  102. "The output size is not aligned with the quantized "
  103. "weight shape. This can be caused by too large "
  104. "tensor parallel size.")
  105. if self.quant_config.group_size != -1:
  106. group_size = self.quant_config.group_size
  107. else:
  108. group_size = input_size
  109. exllama_state = ExllamaState.UNINITIALIZED
  110. scale_and_zero_size = input_size // group_size
  111. scale_and_zero_input_dim = None
  112. if (input_size != input_size_per_partition
  113. and self.quant_config.group_size != -1):
  114. # For act-order models, we cannot use Exllama for row parallel layer
  115. if self.quant_config.desc_act:
  116. exllama_state = ExllamaState.UNUSED
  117. else:
  118. # we need to partition qzeros and scales for exllama kernel
  119. scale_and_zero_size = input_size_per_partition // group_size
  120. scale_and_zero_input_dim = 0
  121. qweight = Parameter(
  122. torch.empty(
  123. input_size_per_partition // self.quant_config.pack_factor,
  124. output_size_per_partition,
  125. dtype=torch.int32,
  126. ),
  127. requires_grad=False,
  128. )
  129. set_weight_attrs(
  130. qweight,
  131. {
  132. "input_dim": 0,
  133. "output_dim": 1,
  134. "packed_dim": 0,
  135. "pack_factor": self.quant_config.pack_factor,
  136. },
  137. )
  138. g_idx = Parameter(
  139. torch.tensor(
  140. [
  141. i // self.quant_config.group_size
  142. for i in range(input_size_per_partition)
  143. ],
  144. dtype=torch.int32,
  145. ),
  146. requires_grad=False,
  147. )
  148. # Ignore warning from fused linear layers such as QKVParallelLinear.
  149. set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
  150. qzeros = Parameter(
  151. torch.empty(
  152. scale_and_zero_size,
  153. output_size_per_partition // self.quant_config.pack_factor,
  154. dtype=torch.int32,
  155. ),
  156. requires_grad=False,
  157. )
  158. set_weight_attrs(
  159. qzeros,
  160. {
  161. "input_dim": scale_and_zero_input_dim,
  162. "output_dim": 1,
  163. "packed_dim": 1,
  164. "pack_factor": self.quant_config.pack_factor,
  165. },
  166. )
  167. scales = Parameter(
  168. torch.empty(
  169. scale_and_zero_size,
  170. output_size_per_partition,
  171. dtype=params_dtype,
  172. ),
  173. requires_grad=False,
  174. )
  175. set_weight_attrs(
  176. scales,
  177. {
  178. "input_dim": scale_and_zero_input_dim,
  179. "output_dim": 1,
  180. },
  181. )
  182. return {
  183. "qweight": qweight,
  184. "g_idx": g_idx,
  185. "qzeros": qzeros,
  186. "scales": scales,
  187. "exllama_state": exllama_state,
  188. }
  189. def apply_weights(
  190. self,
  191. weights: Dict[str, Any],
  192. x: torch.Tensor,
  193. bias: Optional[torch.Tensor] = None,
  194. ) -> torch.Tensor:
  195. qweight = weights["qweight"]
  196. out_shape = x.shape[:-1] + (qweight.shape[-1], )
  197. reshaped_x = x.reshape(-1, x.shape[-1])
  198. # exllama needs to shuffle the weight after the weight is loaded
  199. # here we do the shuffle on first forward pass
  200. if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
  201. if self.quant_config.desc_act:
  202. weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
  203. torch.int)
  204. else:
  205. weights["g_idx"] = torch.empty((1, 1), device="meta")
  206. weights["exllama_state"] = ExllamaState.READY
  207. ops.gptq_shuffle(
  208. weights["qweight"],
  209. weights["g_idx"],
  210. self.quant_config.weight_bits,
  211. )
  212. output = ops.gptq_gemm(
  213. reshaped_x,
  214. weights["qweight"],
  215. weights["qzeros"],
  216. weights["scales"],
  217. weights["g_idx"],
  218. weights["exllama_state"] == ExllamaState.READY,
  219. self.quant_config.weight_bits,
  220. )
  221. if bias is not None:
  222. output = output + bias
  223. return output.reshape(out_shape)
  224. def apply_moe_weights(
  225. self,
  226. w1: Dict[str, torch.Tensor],
  227. w2: Dict[str, torch.Tensor],
  228. x: torch.Tensor,
  229. gating_output: torch.Tensor,
  230. topk: int,
  231. renormalize: bool,
  232. ) -> torch.Tensor:
  233. # shuffle weights for exllama
  234. # ignore marlin now which doesn't support fuse moe yet
  235. for w in [w1, w2]:
  236. if w["exllama_state"] == ExllamaState.UNINITIALIZED:
  237. if self.quant_config.desc_act:
  238. w["g_idx"] = torch.argsort(w["g_idx"],
  239. dim=-1).to(torch.int)
  240. else:
  241. w["g_idx"] = torch.empty((1, 1), device="meta")
  242. w["exllama_state"] = ExllamaState.READY
  243. ops.gptq_shuffle(w["qweight"], w["g_idx"],
  244. self.quant_config.weight_bits)
  245. if x.shape[0] >= 128:
  246. dequant_w1 = ops.dequant_gptq(
  247. w1["qweight"],
  248. w1["qzeros"],
  249. w1["scales"],
  250. w1["g_idx"],
  251. self.quant_config.weight_bits,
  252. w1["exllama_state"] == ExllamaState.READY,
  253. ).permute(0, 2, 1)
  254. dequant_w2 = ops.dequant_gptq(
  255. w2["qweight"],
  256. w2["qzeros"],
  257. w2["scales"],
  258. w2["g_idx"],
  259. self.quant_config.weight_bits,
  260. w2["exllama_state"] == ExllamaState.READY,
  261. ).permute(0, 2, 1)
  262. return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk,
  263. renormalize)
  264. topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
  265. (
  266. sorted_token_ids,
  267. expert_ids,
  268. num_tokens_post_padded,
  269. ) = moe_align_block_size(topk_ids, 8, w1["qweight"].shape[0])
  270. x = x.view(x.shape[0], 1, *x.shape[1:])
  271. gate_up = ops.group_gptq_gemm(
  272. x,
  273. w1["qweight"],
  274. w1["qzeros"],
  275. w1["scales"],
  276. w1["g_idx"],
  277. topk_weights,
  278. sorted_token_ids,
  279. expert_ids,
  280. num_tokens_post_padded,
  281. False,
  282. w1["exllama_state"] == ExllamaState.READY,
  283. )
  284. out = torch.empty(
  285. (gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )),
  286. dtype=x.dtype,
  287. device=x.device,
  288. )
  289. _C_ops.silu_and_mul(out, gate_up)
  290. out = ops.group_gptq_gemm(
  291. out,
  292. w2["qweight"],
  293. w2["qzeros"],
  294. w2["scales"],
  295. w2["g_idx"],
  296. topk_weights,
  297. sorted_token_ids,
  298. expert_ids,
  299. num_tokens_post_padded,
  300. True,
  301. w2["exllama_state"] == ExllamaState.READY,
  302. )
  303. return torch.sum(out, dim=1)