awq.py

from typing import Any, Dict, List, Optional
from contextlib import suppress

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.fused_moe import (moe_align_block_size,
                                                 fused_moe, fused_topk)
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Number of quantized values packed into each int32 (8 for 4-bit).
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)
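
    # Illustrative example (values assumed, not taken from any specific
    # checkpoint): a quantize_config.json containing
    #     {"w_bit": 4, "q_group_size": 128, "zero_point": true}
    # is parsed by from_config into
    #     AWQConfig(weight_bits=4, group_size=128, zero_point=True).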

    def get_linear_method(self) -> "AWQLinearMethod":
        return AWQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]

    def merge_weight(self) -> bool:
        return True

    def rope_style(self) -> Optional[bool]:
        return None

    def quant_vocab(self) -> List[bool]:
        return [False, False]

    def support_fused_moe(self) -> bool:
        return True


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        # Packed 4-bit weights: each int32 along the output dimension holds
        # `pack_factor` quantized values.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        # Packed zero points, one per (group, output channel) pair.
        qzeros = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        # Per-group dequantization scales, stored in the activation dtype.
        scales = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qzeros", qzeros)
        set_weight_attrs(qzeros, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])

        # Heuristic: once enough tokens are batched (num_tokens >= 256),
        # dequantize to fp16 and use a dense matmul instead of the fused
        # int4 GEMM kernel.
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
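
    # Note on the fp16 fallback in apply_weights above: ops.awq_dequantize
    # expands the packed int32 qweight of shape
    # (input_size, output_size // pack_factor) into a fp16 matrix of shape
    # (input_size, output_size), so the plain torch.matmul computes the same
    # linear transform as the fused awq_gemm kernel.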

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        # For large batches, dequantize both expert weight sets to fp16 and
        # reuse the generic fused MoE kernel.
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024
        if FP16_MATMUL_HEURISTIC_CONDITION:
            dequant_w1 = ops.awq_dequantize(w1["qweight"], w1["scales"],
                                            w1["qzeros"], 0, 0,
                                            0).permute(0, 2, 1)
            dequant_w2 = ops.awq_dequantize(w2["qweight"], w2["scales"],
                                            w2["qzeros"], 0, 0,
                                            0).permute(0, 2, 1)
            return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk,
                             renormalize)

        # Otherwise, route tokens to their top-k experts and run grouped
        # GEMMs directly on the packed int4 weights.
        topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
        (sorted_token_ids, expert_ids,
         num_tokens_post_padded) = moe_align_block_size(
             topk_ids, 16, w1["qweight"].shape[0])

        x = x.view(x.shape[0], 1, *x.shape[1:])

        pack_factor = self.quant_config.pack_factor

        gate_up = ops.awq_group_gemm(x, w1["qweight"], w1["scales"],
                                     w1["qzeros"], topk_weights,
                                     sorted_token_ids, expert_ids,
                                     num_tokens_post_padded, False,
                                     pack_factor)

        # silu_and_mul fuses the SwiGLU-style activation: SiLU of one half of
        # gate_up multiplied elementwise by the other half, halving the last
        # dimension.
        out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )),
                          dtype=x.dtype,
                          device=x.device)
        ops.silu_and_mul(out, gate_up)

        out = ops.awq_group_gemm(out, w2["qweight"], w2["scales"],
                                 w2["qzeros"], topk_weights, sorted_token_ids,
                                 expert_ids, num_tokens_post_padded, True,
                                 pack_factor)

        # Combine the per-expert outputs for each token.
        return torch.sum(out, dim=1)
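

# Shape sketch (illustrative only; the layer sizes below are assumptions, not
# taken from any particular model): for a 4-bit AWQ linear layer with
# input_size_per_partition=4096, output_size_per_partition=11008, and
# group_size=128, pack_factor = 32 // 4 = 8 and create_weights allocates:
#
#   qweight: (4096, 11008 // 8) = (4096, 1376)      int32
#   qzeros:  (4096 // 128, 11008 // 8) = (32, 1376) int32
#   scales:  (4096 // 128, 11008) = (32, 11008)     params_dtype (e.g. fp16)
#
# The bit-level packing order inside each int32 is defined by the AWQ kernels
# in aphrodite._quant_C and is not described here.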