# gptq.py

import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.parameter import (ChannelQuantScaleParameter,
                                          GroupQuantScaleParameter,
                                          PackedAphroditeParameter,
                                          PackedColumnParameter,
                                          RowAphroditeParameter)
from aphrodite.quantization.base_config import QuantizationConfig


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        lm_head_quantized: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
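        # Number of quantized weights packed into one 32-bit storage word,
        # kept as a Fraction so the 3-bit case (32/3) stays exact in the
        # shape arithmetic below.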
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, lm_head_quantized)
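
    # A typical quantize_config.json consumed by from_config looks roughly
    # like this (illustrative values, not taken from any specific checkpoint):
    #   {"bits": 4, "group_size": 128, "desc_act": false, "lm_head": false}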

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["GPTQLinearMethod"]:
        if (isinstance(layer, LinearBase) or
                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return GPTQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
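

# State of the optional Exllama kernel path for a layer's weights:
#   UNUSED        - the Exllama kernel cannot be used for this layer.
#   UNINITIALIZED - it can be used, but the weights have not been shuffled yet.
#   READY         - weights have been shuffled and the kernel can be dispatched.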
class ExllamaState(Enum):
    UNUSED = enum.auto()
    UNINITIALIZED = enum.auto()
    READY = enum.auto()


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs.get("weight_loader")
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if (output_size_per_partition % self.quant_config.pack_factor.numerator
                != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
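
        # group_size == -1 means a single quantization group spanning the
        # whole input dimension, so scales/zeros have just one row.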
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use Exllama for row parallel layer
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0
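
        # Packed quantized weights: pack_factor weights share one int32 along
        # the input dimension, giving shape
        # (input_size_per_partition // pack_factor, output_size_per_partition).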
        qweight = PackedAphroditeParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)
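
        # g_idx maps each input channel to its quantization group; the default
        # below is channel_index // group_size, and checkpoints that ship a
        # g_idx tensor (e.g. act-order models) replace it via weight_loader.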
        g_idx = RowAphroditeParameter(
            data=torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                dtype=torch.int32,
            ),
            input_dim=0,
            weight_loader=weight_loader)
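
        # qzeros and scales have one row per quantization group; qzeros is
        # packed along the output dimension (pack_factor zero-points per int32).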
        qzeros_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if scale_and_zero_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            # qzeros must also be shardable along the input dim here, so use
            # the general packed parameter rather than PackedColumnParameter.
            qzeros = PackedAphroditeParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
        layer.exllama_state = exllama_state

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # for torch.compile
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = Parameter(layer.scales.data, requires_grad=False)
        layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
        # Exllama requires the quantized weight to be shuffled after loading;
        # do the shuffle here, once the weights are in place.
        if layer.exllama_state == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
            else:
                layer.g_idx.data = torch.empty((0, ),
                                               dtype=torch.int,
                                               device=layer.g_idx.device)
            layer.exllama_state = ExllamaState.READY
            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                             self.quant_config.weight_bits)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
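
        # Fused dequantize + matmul; the boolean flag selects the Exllama
        # kernel when this layer's weights have been shuffled (READY).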
        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
                               layer.scales, layer.g_idx,
                               layer.exllama_state == ExllamaState.READY,
                               self.quant_config.weight_bits)
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)