# awq.py (2.3 KB)
  1. from typing import Any, Dict, List
  2. import torch
  3. from aphrodite.modeling.quantization_utils.base import QuantizationConfig
  4. class AWQConfig(QuantizationConfig):
  5. """Config class for AWQ.
  6. Reference: https://arxiv.org/abs/2306.00978
  7. """
  8. def __init__(
  9. self,
  10. weight_bits: int,
  11. group_size: int,
  12. zero_point: bool,
  13. ) -> None:
  14. self.weight_bits = weight_bits
  15. self.group_size = group_size
  16. self.zero_point = zero_point
  17. if self.weight_bits != 4:
  18. raise ValueError(
  19. "Currently, only 4-bit weight quantization is supported for "
  20. f"AWQ, but got {self.weight_bits} bits.")
  21. self.pack_factor = 32 // self.weight_bits
  22. def __repr__(self) -> str:
  23. return (f"AWQConfig(weight_bits={self.weight_bits}, "
  24. f"group_size={self.group_size}, "
  25. f"zero_point={self.zero_point})")
  26. @classmethod
  27. def get_name(cls) -> str:
  28. return "awq"
  29. @classmethod
  30. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  31. return [torch.half]
  32. @classmethod
  33. def get_min_capability(cls) -> int:
  34. # The AWQ kernel only supports Turing or newer GPUs.
  35. return 75
  36. @classmethod
  37. def get_config_filenames(cls) -> List[str]:
  38. return [
  39. "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
  40. "quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
  41. ]
  42. @classmethod
  43. def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
  44. weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
  45. group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
  46. zero_point = cls.get_from_keys(config, ["zero_point"])
  47. return cls(weight_bits, group_size, zero_point)
  48. @classmethod
  49. def get_packed_tensor_names(cls) -> List[str]:
  50. return ["qweight", "qzeros"]
  51. @classmethod
  52. def get_transposed_tensor_names(cls) -> List[str]:
  53. return ["qweight", "qzeros", "scales"]
  54. def get_row_tp_tensor_names(self) -> List[str]:
  55. return ["qweight", "qzeros", "scales"]
  56. def get_column_tp_tensor_names(self) -> List[str]:
  57. return ["qweight", "qzeros", "scales", "bias"]