awq.py

from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True
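# NOTE: HAS_QUANTS stays False when the compiled quantization kernels
# (aphrodite._quant_C) cannot be imported; the AWQ paths below
# (ops.awq_dequantize, ops.awq_gemm) require that extension.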
class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Number of quantized values packed into a single int32 word
        # (8 for 4-bit weights).
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
        if isinstance(layer, LinearBase):
            return AWQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
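# Illustrative example: a typical AutoAWQ "quantize_config.json" such as
#     {"zero_point": true, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
# is parsed by AWQConfig.from_config into
#     AWQConfig(weight_bits=4, group_size=128, zero_point=True),
# giving pack_factor == 32 // 4 == 8 packed values per int32 word.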
class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        # Fail fast if the compiled quantization kernels are missing; both
        # awq_dequantize and awq_gemm below come from aphrodite._quant_C.
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        qzeros = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qzeros", qzeros)
        set_weight_attrs(qzeros, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)
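    # Weight layout created above, assuming 4-bit AWQ (pack_factor == 8) and
    # an illustrative 4096 -> 11008 partition with group_size == 128:
    #   qweight: (4096, 11008 // 8)        int32   (8 nibbles packed per word)
    #   qzeros:  (4096 // 128, 11008 // 8) int32
    #   scales:  (4096 // 128, 11008)      params_dtype (e.g. torch.half)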
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])

        # Heuristic: for larger batches (num_tokens >= threshold) it is faster
        # to dequantize the weights to FP16 and use a dense matmul than to run
        # the fused AWQ GEMM kernel, which is tuned for small batch sizes.
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
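# Illustrative shapes for apply(), assuming the 4096 -> 11008 layer above
# with pack_factor == 8:
#   x:         (batch_size, seq_len, 4096) -> reshaped_x: (num_tokens, 4096)
#   out_shape: x.shape[:-1] + (qweight.shape[-1] * 8,)
#              == (batch_size, seq_len, 11008)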