1
0

awq.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. from typing import Any, Dict, List, Optional
  2. from contextlib import suppress
  3. import torch
  4. from torch.nn.parameter import Parameter
  5. from aphrodite.modeling.layers.fused_moe import (moe_align_block_size,
  6. fused_moe, fused_topk)
  7. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  8. set_weight_attrs)
  9. from aphrodite.quantization.base_config import (QuantizationConfig)
  10. from aphrodite._C import ops as _C_ops
  11. HAS_QUANTS = False
  12. with suppress(ImportError):
  13. from aphrodite._quant_C import quant_ops as ops
  14. HAS_QUANTS = True
  15. class AWQConfig(QuantizationConfig):
  16. """Config class for AWQ.
  17. Reference: https://arxiv.org/abs/2306.00978
  18. """
  19. def __init__(
  20. self,
  21. weight_bits: int,
  22. group_size: int,
  23. zero_point: bool,
  24. ) -> None:
  25. if not HAS_QUANTS:
  26. raise ImportError("Could not find the quantization kernels.")
  27. self.weight_bits = weight_bits
  28. self.group_size = group_size
  29. self.zero_point = zero_point
  30. if self.weight_bits != 4:
  31. raise ValueError(
  32. "Currently, only 4-bit weight quantization is supported for "
  33. f"AWQ, but got {self.weight_bits} bits.")
  34. self.pack_factor = 32 // self.weight_bits
  35. def __repr__(self) -> str:
  36. return (f"AWQConfig(weight_bits={self.weight_bits}, "
  37. f"group_size={self.group_size}, "
  38. f"zero_point={self.zero_point})")
  39. def get_name(self) -> str:
  40. return "awq"
  41. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  42. return [torch.half]
  43. def get_min_capability(self) -> int:
  44. # The AWQ kernel only supports Turing or newer GPUs.
  45. return 75
  46. @staticmethod
  47. def get_config_filenames() -> List[str]:
  48. return [
  49. "quant_config.json",
  50. "quantize_config.json",
  51. ]
  52. @classmethod
  53. def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
  54. weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
  55. group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
  56. zero_point = cls.get_from_keys(config, ["zero_point"])
  57. return cls(weight_bits, group_size, zero_point)
  58. def get_linear_method(self) -> "AWQLinearMethod":
  59. return AWQLinearMethod(self)
  60. def get_scaled_act_names(self) -> List[str]:
  61. return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
  62. def merge_weight(self) -> bool:
  63. return True
  64. def rope_style(self) -> Optional[bool]:
  65. return None
  66. def quant_vocab(self) -> List[bool]:
  67. return [False, False]
  68. def support_fused_moe(self) -> bool:
  69. return True
  70. class AWQLinearMethod(LinearMethodBase):
  71. """Linear method for AWQ.
  72. Args:
  73. quant_config: The AWQ quantization config.
  74. """
  75. def __init__(self, quant_config: AWQConfig):
  76. self.quant_config = quant_config
  77. def create_weights(
  78. self,
  79. input_size_per_partition: int,
  80. output_partition_sizes: List[int],
  81. input_size: int,
  82. output_size: int,
  83. params_dtype: torch.dtype,
  84. ) -> Dict[str, Any]:
  85. if input_size_per_partition % self.quant_config.group_size != 0:
  86. raise ValueError(
  87. "The input size is not aligned with the quantized "
  88. "weight shape. This can be caused by too large "
  89. "tensor parallel size.")
  90. output_size_per_partition = sum(output_partition_sizes)
  91. if output_size_per_partition % self.quant_config.pack_factor != 0:
  92. raise ValueError(
  93. "The output size is not aligned with the quantized "
  94. "weight shape. This can be caused by too large "
  95. "tensor parallel size.")
  96. qweight = Parameter(
  97. torch.empty(
  98. input_size_per_partition,
  99. output_size_per_partition // self.quant_config.pack_factor,
  100. dtype=torch.int32,
  101. ),
  102. requires_grad=False,
  103. )
  104. set_weight_attrs(
  105. qweight, {
  106. "input_dim": 0,
  107. "output_dim": 1,
  108. "packed_dim": 1,
  109. "pack_factor": self.quant_config.pack_factor,
  110. })
  111. qzeros = Parameter(
  112. torch.empty(
  113. input_size_per_partition // self.quant_config.group_size,
  114. output_size_per_partition // self.quant_config.pack_factor,
  115. dtype=torch.int32,
  116. ),
  117. requires_grad=False,
  118. )
  119. set_weight_attrs(
  120. qzeros, {
  121. "input_dim": 0,
  122. "output_dim": 1,
  123. "packed_dim": 1,
  124. "pack_factor": self.quant_config.pack_factor,
  125. })
  126. scales = Parameter(
  127. torch.empty(
  128. input_size_per_partition // self.quant_config.group_size,
  129. output_size_per_partition,
  130. dtype=params_dtype,
  131. ),
  132. requires_grad=False,
  133. )
  134. set_weight_attrs(scales, {
  135. "input_dim": 0,
  136. "output_dim": 1,
  137. })
  138. return {
  139. "qweight": qweight,
  140. "qzeros": qzeros,
  141. "scales": scales,
  142. }
  143. def apply_weights(self,
  144. weights: Dict[str, Any],
  145. x: torch.Tensor,
  146. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  147. qweight = weights["qweight"]
  148. qzeros = weights["qzeros"]
  149. scales = weights["scales"]
  150. pack_factor = self.quant_config.pack_factor
  151. out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
  152. reshaped_x = x.reshape(-1, x.shape[-1])
  153. # num_tokens >= threshold
  154. FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
  155. if FP16_MATMUL_HEURISTIC_CONDITION:
  156. out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
  157. out = torch.matmul(reshaped_x, out)
  158. else:
  159. out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
  160. pack_factor)
  161. if bias is not None:
  162. out = out + bias
  163. return out.reshape(out_shape)
  164. def apply_moe_weights(self, w1: Dict[str,
  165. torch.Tensor], w2: Dict[str,
  166. torch.Tensor],
  167. x: torch.Tensor, gating_output: torch.Tensor,
  168. topk: int, renormalize: bool) -> torch.Tensor:
  169. FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024
  170. if FP16_MATMUL_HEURISTIC_CONDITION:
  171. dequant_w1 = ops.awq_dequantize(w1["qweight"], w1["scales"],
  172. w1["qzeros"], 0, 0,
  173. 0).permute(0, 2, 1)
  174. dequant_w2 = ops.awq_dequantize(w2["qweight"], w2["scales"],
  175. w2["qzeros"], 0, 0,
  176. 0).permute(0, 2, 1)
  177. return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk,
  178. renormalize)
  179. topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
  180. (sorted_token_ids, expert_ids,
  181. num_tokens_post_padded) = moe_align_block_size(
  182. topk_ids, 16, w1["qweight"].shape[0])
  183. x = x.view(x.shape[0], 1, *x.shape[1:])
  184. pack_factor = self.quant_config.pack_factor
  185. gate_up = ops.awq_group_gemm(x, w1["qweight"], w1["scales"],
  186. w1["qzeros"], topk_weights,
  187. sorted_token_ids, expert_ids,
  188. num_tokens_post_padded, False,
  189. pack_factor)
  190. out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )),
  191. dtype=x.dtype,
  192. device=x.device)
  193. _C_ops.silu_and_mul(out, gate_up)
  194. out = ops.awq_group_gemm(out, w2["qweight"], w2["scales"],
  195. w2["qzeros"], topk_weights, sorted_token_ids,
  196. expert_ids, num_tokens_post_padded, True,
  197. pack_factor)
  198. return torch.sum(out, dim=1)