gptq.py

import enum
from contextlib import suppress
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True
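
# Note (added): the compiled quantization kernels are optional at import time;
# GPTQLinearMethod below raises ImportError only when a GPTQ layer is actually
# constructed without them.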


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
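        # Added note: pack_factor is kept as an exact Fraction -- 32 /
        # weight_bits quantized values fit in one int32 word (e.g. 8 for
        # 4-bit); for 3-bit the ratio is not an integer, which is presumably
        # why a Fraction is used rather than a plain integer division.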
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
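        # Added note (assumption based on the QuantizationConfig convention):
        # capabilities are encoded as major * 10 + minor, so 60 corresponds to
        # CUDA compute capability 6.0 (Pascal).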
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

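    # Added note: from_config below expects the dict loaded from an
    # AutoGPTQ-style quantize_config.json, typically something like
    # {"bits": 4, "group_size": 128, "desc_act": false} (illustrative values).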
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        return cls(weight_bits, group_size, desc_act)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class ExllamaState(Enum):

    UNUSED = enum.auto()
    UNINITIALIZED = enum.auto()
    READY = enum.auto()
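
# Added note: ExllamaState lifecycle -- GPTQLinearMethod.create_weights starts
# each layer as UNINITIALIZED (or UNUSED when the exllama path cannot be used,
# i.e. act-order weights on a row-parallel shard), and apply() performs the
# one-time weight/g_idx shuffle on the first forward pass before marking the
# layer READY.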


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
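        # Added note: the check below uses pack_factor.numerator, which is 8
        # for 4-bit weights but 32 for 3-bit (Fraction(32, 3) does not reduce),
        # so 3-bit checkpoints need an output dimension divisible by a larger
        # multiple.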
        if (output_size_per_partition % self.quant_config.pack_factor.numerator
                != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use Exllama for row-parallel
            # layers.
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # We need to partition qzeros and scales for the exllama
                # kernel.
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        g_idx = Parameter(
            torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
        qzeros = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": scale_and_zero_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 1,
        })
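        # Illustration (added): with 4-bit weights (pack_factor 8), group_size
        # 128 and a 4096 x 4096 shard, the tensors above come out as
        #   qweight: [512, 4096]   int32  (4096 // 8 packed rows)
        #   g_idx:   [4096]        int32
        #   qzeros:  [32, 512]     int32  (32 groups, packed columns)
        #   scales:  [32, 4096]    params_dtype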

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("g_idx", g_idx)
        set_weight_attrs(g_idx, extra_weight_attrs)
        layer.register_parameter("qzeros", qzeros)
        set_weight_attrs(qzeros, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

        layer.exllama_state = exllama_state

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.qweight
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # Exllama needs to shuffle the weight after it is loaded; we do the
        # shuffle on the first forward pass.
        if layer.exllama_state == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
            else:
                layer.g_idx.data = torch.empty((0, ),
                                               device=layer.g_idx.device)
            layer.exllama_state = ExllamaState.READY
            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                             self.quant_config.weight_bits)
        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
                               layer.scales, layer.g_idx,
                               layer.exllama_state == ExllamaState.READY,
                               self.quant_config.weight_bits)
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
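

# ---------------------------------------------------------------------------
# Usage sketch (added, illustrative only): parse an assumed AutoGPTQ-style
# quantize_config.json dict into a GPTQConfig. The values below are made up;
# real checkpoints define them in the exported config file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_config = {"bits": 4, "group_size": 128, "desc_act": False}
    gptq_config = GPTQConfig.from_config(example_config)
    print(gptq_config)              # GPTQConfig(weight_bits=4, ...)
    print(gptq_config.pack_factor)  # 8 -> eight 4-bit values per int32 word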