gptq.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import enum
  2. from enum import Enum
  3. from fractions import Fraction
  4. from typing import Any, Dict, List, Optional
  5. import torch
  6. from torch.nn.parameter import Parameter
  7. from aphrodite import _custom_ops as ops
  8. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  9. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  10. from aphrodite.modeling.utils import set_weight_attrs
  11. from aphrodite.quantization.base_config import QuantizationConfig
  12. class GPTQConfig(QuantizationConfig):
  13. """Config class for GPTQ.
  14. Reference: https://arxiv.org/abs/2210.17323
  15. """
  16. def __init__(
  17. self,
  18. weight_bits: int,
  19. group_size: int,
  20. desc_act: bool,
  21. lm_head_quantized: bool,
  22. ) -> None:
  23. self.weight_bits = weight_bits
  24. self.group_size = group_size
  25. self.desc_act = desc_act
  26. self.lm_head_quantized = lm_head_quantized
  27. self.pack_factor = Fraction(32, self.weight_bits)
  28. if self.weight_bits not in [2, 3, 4, 8]:
  29. raise ValueError(
  30. "Currently, only 2/3/4/8-bit weight quantization is "
  31. f"supported for GPTQ, but got {self.weight_bits} bits.")
  32. def __repr__(self) -> str:
  33. return (f"GPTQConfig(weight_bits={self.weight_bits}, "
  34. f"group_size={self.group_size}, "
  35. f"desc_act={self.desc_act}),"
  36. f"lm_head_quantized={self.lm_head_quantized}")
  37. @classmethod
  38. def get_name(cls) -> str:
  39. return "gptq"
  40. @classmethod
  41. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  42. return [torch.half]
  43. @classmethod
  44. # Need to figure it out
  45. def get_min_capability(cls) -> int:
  46. return 60
  47. @classmethod
  48. def get_config_filenames(cls) -> List[str]:
  49. return ["quantize_config.json"]
  50. @classmethod
  51. def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
  52. weight_bits = cls.get_from_keys(config, ["bits"])
  53. group_size = cls.get_from_keys(config, ["group_size"])
  54. desc_act = cls.get_from_keys(config, ["desc_act"])
  55. lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
  56. default=False)
  57. return cls(weight_bits, group_size, desc_act, lm_head_quantized)
  58. def get_quant_method(self, layer: torch.nn.Module,
  59. prefix: str) -> Optional["GPTQLinearMethod"]:
  60. if (isinstance(layer, LinearBase) or
  61. (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
  62. return GPTQLinearMethod(self)
  63. return None
  64. def get_scaled_act_names(self) -> List[str]:
  65. return []
  66. class ExllamaState(Enum):
  67. UNUSED = enum.auto()
  68. UNINITIALIZED = enum.auto()
  69. READY = enum.auto()
  70. class GPTQLinearMethod(LinearMethodBase):
  71. """Linear method for GPTQ.
  72. Args:
  73. quant_config: The GPTQ quantization config.
  74. """
  75. def __init__(self, quant_config: GPTQConfig):
  76. self.quant_config = quant_config
  77. def create_weights(
  78. self,
  79. layer: torch.nn.Module,
  80. input_size_per_partition: int,
  81. output_partition_sizes: List[int],
  82. input_size: int,
  83. output_size: int,
  84. params_dtype: torch.dtype,
  85. **extra_weight_attrs,
  86. ):
  87. del output_size # Unused.
  88. if input_size_per_partition % self.quant_config.group_size != 0:
  89. raise ValueError(
  90. "The input size is not aligned with the quantized "
  91. "weight shape. This can be caused by too large "
  92. "tensor parallel size.")
  93. output_size_per_partition = sum(output_partition_sizes)
  94. if (output_size_per_partition % self.quant_config.pack_factor.numerator
  95. != 0):
  96. raise ValueError(
  97. "The output size is not aligned with the quantized "
  98. "weight shape. This can be caused by too large "
  99. "tensor parallel size.")
  100. if self.quant_config.group_size != -1:
  101. group_size = self.quant_config.group_size
  102. else:
  103. group_size = input_size
  104. exllama_state = ExllamaState.UNINITIALIZED
  105. scale_and_zero_size = input_size // group_size
  106. scale_and_zero_input_dim = None
  107. if (input_size != input_size_per_partition
  108. and self.quant_config.group_size != -1):
  109. # For act-order models, we cannot use Exllama for row parallel layer
  110. if self.quant_config.desc_act:
  111. exllama_state = ExllamaState.UNUSED
  112. else:
  113. # we need to partition qzeros and scales for exllama kernel
  114. scale_and_zero_size = input_size_per_partition // group_size
  115. scale_and_zero_input_dim = 0
  116. qweight = Parameter(
  117. torch.empty(
  118. input_size_per_partition // self.quant_config.pack_factor,
  119. output_size_per_partition,
  120. dtype=torch.int32,
  121. ),
  122. requires_grad=False,
  123. )
  124. set_weight_attrs(
  125. qweight, {
  126. "input_dim": 0,
  127. "output_dim": 1,
  128. "packed_dim": 0,
  129. "pack_factor": self.quant_config.pack_factor,
  130. })
  131. g_idx = Parameter(
  132. torch.tensor(
  133. [
  134. i // self.quant_config.group_size
  135. for i in range(input_size_per_partition)
  136. ],
  137. dtype=torch.int32,
  138. ),
  139. requires_grad=False,
  140. )
  141. # Ignore warning from fused linear layers such as QKVParallelLinear.
  142. set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
  143. qzeros = Parameter(
  144. torch.empty(
  145. scale_and_zero_size,
  146. output_size_per_partition // self.quant_config.pack_factor,
  147. dtype=torch.int32,
  148. ),
  149. requires_grad=False,
  150. )
  151. set_weight_attrs(
  152. qzeros, {
  153. "input_dim": scale_and_zero_input_dim,
  154. "output_dim": 1,
  155. "packed_dim": 1,
  156. "pack_factor": self.quant_config.pack_factor,
  157. })
  158. scales = Parameter(
  159. torch.empty(
  160. scale_and_zero_size,
  161. output_size_per_partition,
  162. dtype=params_dtype,
  163. ),
  164. requires_grad=False,
  165. )
  166. set_weight_attrs(scales, {
  167. "input_dim": scale_and_zero_input_dim,
  168. "output_dim": 1,
  169. })
  170. layer.register_parameter("qweight", qweight)
  171. set_weight_attrs(qweight, extra_weight_attrs)
  172. layer.register_parameter("g_idx", g_idx)
  173. set_weight_attrs(g_idx, extra_weight_attrs)
  174. layer.register_parameter("qzeros", qzeros)
  175. set_weight_attrs(qzeros, extra_weight_attrs)
  176. layer.register_parameter("scales", scales)
  177. set_weight_attrs(scales, extra_weight_attrs)
  178. layer.exllama_state = exllama_state
  179. def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
  180. # exllama needs to shuffle the weight after the weight is loaded
  181. # here we do the shuffle on first forward pass
  182. if layer.exllama_state == ExllamaState.UNINITIALIZED:
  183. if self.quant_config.desc_act:
  184. layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
  185. else:
  186. layer.g_idx.data = torch.empty((0, ),
  187. dtype=torch.int,
  188. device=layer.g_idx.device)
  189. layer.exllama_state = ExllamaState.READY
  190. ops.gptq_shuffle(layer.qweight, layer.g_idx,
  191. self.quant_config.weight_bits)
  192. def apply(self,
  193. layer: torch.nn.Module,
  194. x: torch.Tensor,
  195. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  196. out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
  197. reshaped_x = x.reshape(-1, x.shape[-1])
  198. output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
  199. layer.scales, layer.g_idx,
  200. layer.exllama_state == ExllamaState.READY,
  201. self.quant_config.weight_bits)
  202. if bias is not None:
  203. output.add_(bias)
  204. return output.reshape(out_shape)