gptq.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from typing import Any, Dict, List
  2. import torch
  3. from aphrodite.modeling.quantization_utils.base import QuantizationConfig
  4. class GPTQConfig(QuantizationConfig):
  5. """Config class for GPTQ.
  6. Reference: https://arxiv.org/abs/2306.00978
  7. """
  8. def __init__(
  9. self,
  10. weight_bits: int,
  11. group_size: int,
  12. desc_act: bool,
  13. ) -> None:
  14. self.weight_bits = weight_bits
  15. self.group_size = group_size
  16. self.desc_act = desc_act
  17. self.pack_factor = 32 // self.weight_bits
  18. # exllama kernel v1 only supports 4 bit
  19. if self.weight_bits != 4:
  20. raise ValueError(
  21. "Currently, only 4-bit weight quantization is supported for "
  22. f"GPTQ, but got {self.weight_bits} bits.")
  23. def __repr__(self) -> str:
  24. return (f"GPTQConfig(weight_bits={self.weight_bits}, "
  25. f"group_size={self.group_size}, "
  26. f"desc_act={self.desc_act})")
  27. @classmethod
  28. def get_name(cls) -> str:
  29. return "gptq"
  30. @classmethod
  31. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  32. return [torch.half]
  33. @classmethod
  34. # Need to figure it out
  35. def get_min_capability(cls) -> int:
  36. return 60
  37. @classmethod
  38. def get_config_filenames(cls) -> List[str]:
  39. return [
  40. "quantize_config.json",
  41. ]
  42. @classmethod
  43. def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
  44. weight_bits = cls.get_from_keys(config, ["bits"])
  45. group_size = cls.get_from_keys(config, ["group_size"])
  46. desc_act = cls.get_from_keys(config, ["desc_act"])
  47. return cls(weight_bits, group_size, desc_act)
  48. @classmethod
  49. def get_packed_tensors(cls) -> Dict[str, int]:
  50. return {"qzeros": 1}
  51. @classmethod
  52. def get_transposed_tensor_names(cls) -> List[str]:
  53. return ["qweight", "qzeros", "scales"]
  54. def get_row_parallel_tensor_names(self) -> List[str]:
  55. if self.desc_act or self.group_size == -1:
  56. return ["qweight", "g_idx"]
  57. return ["qweight", "qzeros", "scales", "g_idx"]
  58. def get_col_parallel_tensor_names(self) -> List[str]:
  59. return ["qweight", "qzeros", "scales", "bias"]