awq.py

from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig

# The compiled quantization kernels are optional at import time; HAS_QUANTS
# records whether they could be loaded.
HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)

    def get_linear_method(self) -> "AWQLinearMethod":
        return AWQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
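
# Illustrative only (not from the original file): a minimal sketch of how a
# checkpoint's quantization config maps onto AWQConfig. The key names below
# ("bits", "group_size", "zero_point") are among those from_config() looks up
# above; the concrete values are assumptions about a typical 4-bit AWQ
# checkpoint.
#
#     cfg = AWQConfig.from_config({
#         "bits": 4,          # parsed via ["w_bit", "bits"] -> weight_bits
#         "group_size": 128,  # parsed via ["q_group_size", "group_size"]
#         "zero_point": True,
#     })
#     assert cfg.pack_factor == 8  # 32 // 4 int4 values per int32 word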


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        qzeros = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qzeros", qzeros)
        set_weight_attrs(qzeros, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)
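
    # Shape example (added comment; the sizes are illustrative): with
    # weight_bits=4 (pack_factor = 32 // 4 = 8), group_size=128, and a
    # 4096x4096 partition, create_weights() registers
    #   qweight: [4096, 512]  int32        (8 packed int4 values per int32)
    #   qzeros:  [  32, 512]  int32        (one packed zero-point row per group)
    #   scales:  [  32, 4096] params_dtype (one scale row per group of 128 inputs)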

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])

        # Heuristic: once the batch holds enough tokens (num_tokens >= 256),
        # dequantizing the weights and running a dense fp16 matmul is
        # typically faster than the fused AWQ GEMM kernel, which favors
        # small batches.
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
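
# Illustrative usage sketch (not part of the original module). It assumes a
# CUDA device, that the compiled aphrodite._quant_C kernels imported above are
# available (HAS_QUANTS is True), and that real qweight/qzeros/scales tensors
# are copied in from an AWQ checkpoint; the 4096 sizes are placeholders.
#
#     config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)
#     method = config.get_linear_method()
#     layer = torch.nn.Module()
#     method.create_weights(layer,
#                           input_size_per_partition=4096,
#                           output_partition_sizes=[4096],
#                           input_size=4096,
#                           output_size=4096,
#                           params_dtype=torch.half)
#     # ... load the checkpoint's qweight/qzeros/scales into `layer` ...
#     x = torch.randn(8, 4096, dtype=torch.half, device="cuda")
#     y = method.apply_weights(layer, x)  # shape [8, 4096]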