awq.py 6.2 KB

from typing import Any, Dict, List, Optional

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.parameter import (GroupQuantScaleParameter,
                                          PackedAphroditeParameter)
from aphrodite.quantization.base_config import QuantizationConfig


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        # Eight 4-bit values are packed into each 32-bit storage word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs
        # (compute capability 7.5+).
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)
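    # A minimal sketch of the kind of config dict from_config expects
    # (illustrative contents, not taken from this repo; AutoAWQ-style
    # checkpoints typically ship something like this in quant_config.json):
    #
    #     {"w_bit": 4, "q_group_size": 128, "zero_point": true}
    #
    # which would yield AWQConfig(weight_bits=4, group_size=128,
    # zero_point=True) with pack_factor = 32 // 4 = 8.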
    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AWQLinearMethod"]:
        if isinstance(layer, LinearBase):
            return AWQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        # Quantized weights: 4-bit values packed along the output dimension,
        # eight per int32 word.
        qweight = PackedAphroditeParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # Packed zero points: one row per quantization group.
        qzeros = PackedAphroditeParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # Per-group scales, kept in the activation dtype (typically fp16).
        scales = GroupQuantScaleParameter(data=torch.empty(
            input_size_per_partition // self.quant_config.group_size,
            output_size_per_partition,
            dtype=params_dtype,
        ),
                                          input_dim=0,
                                          output_dim=1,
                                          weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.qweight = torch.nn.Parameter(layer.qweight.data,
                                           requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
                                          requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data,
                                          requires_grad=False)
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])

        # For large batches (num_tokens >= threshold), dequantizing the
        # weights and running a dense FP16 matmul is faster than the fused
        # AWQ GEMM kernel.
        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

        if FP16_MATMUL_HEURISTIC_CONDITION:
            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
            out = torch.matmul(reshaped_x, out)
        else:
            out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                               pack_factor)
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
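

# A minimal sketch (not part of the loading path) of the bit arithmetic behind
# pack_factor = 32 // weight_bits = 8: eight 4-bit values occupy one 32-bit
# word, which is why qweight and qzeros store output_size // 8 int32 columns.
# The nibble ordering inside real AWQ checkpoints is defined by the kernels
# and may be interleaved; sequential order is used here purely to illustrate.
if __name__ == "__main__":
    vals = [3, 7, 1, 15, 0, 9, 12, 5]            # eight 4-bit values
    packed = 0
    for i, v in enumerate(vals):
        packed |= (v & 0xF) << (4 * i)           # nibble i -> bits [4i, 4i+4)
    unpacked = [(packed >> (4 * i)) & 0xF for i in range(8)]
    assert unpacked == vals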