# awq.py

from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

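# The quantization kernels live in an optionally compiled extension module; if
# the import fails, HAS_QUANTS stays False and AWQLinearMethod refuses to
# initialize instead of crashing at import time.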
HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
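
        # Each 32-bit word stores 32 // weight_bits quantized values along the
        # packed dimension (8 values per word for 4-bit weights).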
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # The AWQ kernel only supports Turing or newer GPUs
        # (compute capability 7.5, hence the value 75).
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
        if isinstance(layer, LinearBase):
            return AWQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
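

# Illustrative sketch, not part of the original module: building the config
# from a parsed checkpoint quantization config. The key names mirror the
# aliases accepted by from_config; the concrete values are hypothetical.
#
#     hf_quant_config = {"w_bit": 4, "q_group_size": 128, "zero_point": True}
#     awq_config = AWQConfig.from_config(hf_quant_config)
#     assert awq_config.pack_factor == 8  # eight 4-bit values per int32 word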


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
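
        # Layout of the packed tensors: qweight holds pack_factor quantized
        # values per int32 word along the output dimension, while qzeros and
        # scales carry one row per group_size input channels. Illustrative
        # shapes (hypothetical sizes, not from the original source): with
        # input_size_per_partition=4096, output_size_per_partition=4096, and
        # group_size=128, qweight is (4096, 512) int32, qzeros is (32, 512)
        # int32, and scales is (32, 4096) in params_dtype.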
        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        qzeros = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })
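
        # Register the packed tensors on the layer and attach any
        # caller-provided attributes (typically a weight loader) on top.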
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qzeros", qzeros)
        set_weight_attrs(qzeros, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        # The packed qweight has out_features // pack_factor columns, so the
        # true output width is recovered by multiplying back by pack_factor.
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])
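
        # For larger batches it is typically faster to dequantize the whole
        # weight to FP16 and run a dense matmul; the fused AWQ GEMM kernel is
        # preferred for small, memory-bound batches.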
        # num_tokens >= threshold
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)