# gptq.py

import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from fractions import Fraction

import torch
from torch.nn.parameter import Parameter

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is supported "
                f"for GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: confirm the minimum compute capability the GPTQ kernels
        # actually require.
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        return cls(weight_bits, group_size, desc_act)

    def get_linear_method(self) -> "GPTQLinearMethod":
        return GPTQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None
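

# Illustrative note (not part of the original file): `from_config` reads the
# keys "bits", "group_size" and "desc_act" from quantize_config.json. A
# typical file produced by GPTQ tooling might contain, e.g.:
#     {"bits": 4, "group_size": 128, "desc_act": false}
# which would yield GPTQConfig(weight_bits=4, group_size=128, desc_act=False)
# with pack_factor == Fraction(32, 4) == 8, i.e. eight 4-bit weights packed
# into each int32 storage element of the tensors created below.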


class ExllamaState(Enum):

    UNUSED = enum.auto()
    UNINITIALIZED = enum.auto()
    READY = enum.auto()
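

# Added commentary: `create_weights` starts each layer as UNINITIALIZED and
# downgrades it to UNUSED when the Exllama kernel cannot be used (act-order
# weights on a row-parallel shard). `apply_weights` performs the one-time
# weight shuffle on the first forward pass and then marks the state READY.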


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if (output_size_per_partition % self.quant_config.pack_factor.numerator
                != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            # group_size == -1 means a single group spanning the input dim.
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use Exllama for row-parallel
            # layers.
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # We need to partition qzeros and scales for the Exllama
                # kernel.
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        # Packed quantized weights: pack_factor weights per int32 along the
        # input dimension.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        # Group index: quantization group of each input channel.
        g_idx = Parameter(
            torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
        # Packed zero points: one row per group, pack_factor zeros per int32
        # along the output dimension.
        qzeros = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": scale_and_zero_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        # Dequantization scales: one per (group, output channel).
        scales = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "g_idx": g_idx,
            "qzeros": qzeros,
            "scales": scales,
            "exllama_state": exllama_state,
        }
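
    # Illustrative example (not from the original file): with 4-bit weights,
    # group_size=128 and no tensor parallelism
    # (input_size == input_size_per_partition == 4096,
    # output_size_per_partition == 4096), pack_factor is 8 and the tensors
    # above come out as roughly:
    #     qweight: int32 [4096 // 8, 4096]          = [512, 4096]
    #     g_idx:   int32 [4096]                     (values i // 128)
    #     qzeros:  int32 [4096 // 128, 4096 // 8]   = [32, 512]
    #     scales:  params_dtype [4096 // 128, 4096] = [32, 4096]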

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # Exllama needs to shuffle the weight after the weight is loaded,
        # so we do the shuffle on the first forward pass here.
        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
                    torch.int)
            else:
                weights["g_idx"] = torch.empty((1, 1), device="meta")
            weights["exllama_state"] = ExllamaState.READY
            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
                             self.quant_config.weight_bits)
        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY,
                               self.quant_config.weight_bits)
        if bias is not None:
            output = output + bias
        return output.reshape(out_shape)
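

# Illustrative usage sketch (not part of the original file; assumes a CUDA
# build of aphrodite._C and a GPTQ checkpoint that fills in the created
# tensors):
#
#     config = GPTQConfig.from_config(
#         {"bits": 4, "group_size": 128, "desc_act": False})
#     method = config.get_linear_method()
#     weights = method.create_weights(
#         input_size_per_partition=4096,
#         output_size_per_partition=4096,
#         input_size=4096,
#         output_size=4096,
#         params_dtype=torch.half,
#     )
#     # ...checkpoint loading copies qweight/qzeros/scales/g_idx in place...
#     y = method.apply_weights(weights, x)  # x: half tensor of shape [..., 4096]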