# awq.py

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)

    def get_linear_method(self) -> "AWQLinearMethod":
        return AWQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None
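

# A minimal usage sketch: parsing an AWQ checkpoint's quantization config dict
# with the class above. The key names ("w_bit", "q_group_size", "zero_point")
# follow the get_from_keys() lookups in from_config(); the literal values are
# illustrative only.
#
#     cfg = AWQConfig.from_config({
#         "w_bit": 4,
#         "q_group_size": 128,
#         "zero_point": True,
#     })
#     assert cfg.pack_factor == 8  # 32 bits per int32 // 4-bit weights
#     linear_method = cfg.get_linear_method()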


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        # Packed 4-bit weights: each int32 element holds `pack_factor`
        # quantized values along the output dimension.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        # Packed zero points: one per (group, output channel), packed along
        # the output dimension like qweight.
        qzeros = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        # Per-group dequantization scales, kept in the activation dtype.
        scales = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "qzeros": qzeros,
            "scales": scales,
        }
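
    # A worked shape example, assuming a 4096 x 4096 layer with group_size=128
    # and 4-bit weights (pack_factor = 32 // 4 = 8), no tensor parallelism:
    #   qweight: (4096, 4096 // 8)   -> (4096, 512), torch.int32
    #   qzeros:  (4096 // 128, 512)  -> (32, 512),   torch.int32
    #   scales:  (4096 // 128, 4096) -> (32, 4096),  params_dtype (e.g. fp16)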

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        qzeros = weights["qzeros"]
        scales = weights["scales"]
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])

        # num_tokens >= threshold
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            # For many tokens, dequantize the weights and use a dense matmul.
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            # For few tokens, use the fused AWQ GEMM kernel directly.
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
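

# A minimal end-to-end sketch, assuming qweight/qzeros/scales have already been
# loaded from an AWQ checkpoint onto a CUDA device; the names and sizes below
# are illustrative only.
#
#     method = AWQLinearMethod(AWQConfig(weight_bits=4, group_size=128,
#                                        zero_point=True))
#     weights = method.create_weights(input_size_per_partition=4096,
#                                     output_partition_sizes=[4096],
#                                     input_size=4096,
#                                     output_size=4096,
#                                     params_dtype=torch.half)
#     # ... copy the checkpoint tensors into weights["qweight"], ["qzeros"],
#     # and ["scales"], then:
#     y = method.apply_weights(weights, x)  # x: (num_tokens, 4096) fp16, CUDA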