gptq.py
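"""GPTQ quantized linear method and configuration, with optional exllama
kernel support."""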

import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite._quant_C import quant_ops as ops
from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
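        # Each 32-bit storage word holds 32 // weight_bits quantized values;
        # a Fraction keeps the ratio exact for widths that do not divide 32
        # evenly (e.g. 3-bit).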
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    # TODO: confirm the actual minimum compute capability for these kernels.
    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
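        # The keys below come from quantize_config.json, e.g. (illustrative):
        #   {"bits": 4, "group_size": 128, "desc_act": false}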
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        return cls(weight_bits, group_size, desc_act)

    def get_linear_method(self) -> "GPTQLinearMethod":
        return GPTQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class ExllamaState(Enum):
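    """Lazy-initialization state of the exllama kernel path for a layer."""
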
    UNUSED = enum.auto()
    UNINITIALIZED = enum.auto()
    READY = enum.auto()


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
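        """Create the packed GPTQ parameters (qweight, g_idx, qzeros, scales)
        on `layer` and record the layer's initial exllama state."""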
        del output_size  # Unused.
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
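        # pack_factor is a Fraction, so its numerator is the smallest number
        # of output channels that packs into whole int32 words.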
        if (output_size_per_partition % self.quant_config.pack_factor.numerator
                != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use Exllama for row parallel
            # layers.
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # We need to partition qzeros and scales for the exllama
                # kernel.
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0
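
        # qweight packs pack_factor quantized weights per int32 word along the
        # input dimension: shape (input_size // pack_factor, output_size).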
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
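        # g_idx maps every input channel to its quantization group; for
        # act-order (desc_act) checkpoints the loaded values override this
        # default ordering.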
        g_idx = Parameter(
            torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
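        # qzeros holds the packed zero-points: one row per group, with
        # pack_factor zero-points per int32 word along the output dimension.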
        qzeros = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": scale_and_zero_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
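        # scales stores one floating-point scale per (group, output channel).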
        scales = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 1,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("g_idx", g_idx)
        set_weight_attrs(g_idx, extra_weight_attrs)
        layer.register_parameter("qzeros", qzeros)
        set_weight_attrs(qzeros, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

        layer.exllama_state = exllama_state

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
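        """Compute x @ dequantize(qweight) (+ bias) with the GPTQ kernels."""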
        qweight = layer.qweight
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # Exllama needs to shuffle the weight after it is loaded, so we do the
        # shuffle on the first forward pass.
        if layer.exllama_state == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
            else:
                layer.g_idx.data = torch.empty((0, ),
                                               device=layer.g_idx.device)
            layer.exllama_state = ExllamaState.READY
            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                             self.quant_config.weight_bits)
        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
                               layer.scales, layer.g_idx,
                               layer.exllama_state == ExllamaState.READY,
                               self.quant_config.weight_bits)
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)