awq.py

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)
from aphrodite.common.logger import init_logger
from aphrodite.common.utils import is_hip

logger = init_logger(__name__)

if is_hip():
    logger.warning("AWQ is not supported on ROCm.")
else:
    from aphrodite._C import ops as quantization_ops
class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq  # pylint: disable=line-too-long
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)

    def get_linear_method(self) -> "AWQLinearMethod":
        return AWQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
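
# Illustrative usage sketch (assumed example, not part of the original module):
# a typical AWQ checkpoint ships a quant_config.json resembling the dict below,
# which from_config maps onto AWQConfig. Key names vary between checkpoints,
# which is why get_from_keys also accepts the aliases "bits" and "group_size".
#
#     config = {"w_bit": 4, "q_group_size": 128, "zero_point": True}
#     awq_config = AWQConfig.from_config(config)
#     assert awq_config.pack_factor == 8  # 32 // 4: eight int4 weights per int32
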
class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        qzeros = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "qzeros": qzeros,
            "scales": scales,
        }
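
    # Worked shape example (illustrative, with assumed sizes): for a layer with
    # input_size_per_partition = output_size_per_partition = 4096,
    # group_size = 128, and weight_bits = 4 (so pack_factor = 32 // 4 = 8),
    # create_weights allocates:
    #   qweight: [4096, 4096 // 8] = [4096, 512]        int32, eight int4 weights packed per entry
    #   qzeros:  [4096 // 128, 4096 // 8] = [32, 512]    int32, packed zero points, one row per group
    #   scales:  [4096 // 128, 4096] = [32, 4096]        params_dtype, one scale per group per output channel
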
    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        qzeros = weights["qzeros"]
        scales = weights["scales"]
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                                        pack_factor)
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
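
# Note (illustrative, sizes assumed): apply_weights flattens x to 2-D before the
# awq_gemm kernel call and restores the leading dimensions afterwards. With the
# example shapes above, x of shape [batch, seq_len, 4096] and qweight of shape
# [4096, 512] give an output reshaped to [batch, seq_len, 512 * 8] =
# [batch, seq_len, 4096].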