gptq.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from typing import Any, Dict, List
  2. import torch
  3. from aphrodite.modeling.quantization_utils.base import QuantizationConfig
  4. class GPTQConfig(QuantizationConfig):
  5. def __init__(
  6. self,
  7. weight_bits: int,
  8. group_size: int,
  9. desc_act: bool,
  10. ) -> None:
  11. self.weight_bits = weight_bits
  12. self.group_size = group_size
  13. self.desc_act = desc_act
  14. self.pack_factor = 32 // self.weight_bits
  15. if self.weight_bits != 4:
  16. raise ValueError(
  17. f"Currently only 4-bit quant is supported for GPTQ, you passed {self.weight_bits} bits."
  18. )
  19. def __repr__(self) -> str:
  20. return (f"GPTQConfig(weight_bits={self.weight_bits}), "
  21. f"group_size={self.group_size}, "
  22. f"desc_act={self.desc_act}")
  23. @classmethod
  24. def get_name(cls) -> str:
  25. return "gptq"
  26. @classmethod
  27. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  28. return [torch.half]
  29. @classmethod
  30. def get_min_capability(cls) -> int:
  31. return 60
  32. @classmethod
  33. def get_config_filenames(cls) -> List[str]:
  34. return [
  35. "quantize_config.json",
  36. ]
  37. @classmethod
  38. def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
  39. weight_bits = cls.get_from_keys(config, ["bits"])
  40. group_size = cls.get_from_keys(config, ["group_size"])
  41. desc_act = cls.get_from_keys(config, ["desc_act"])
  42. return cls(weight_bits, group_size, desc_act)
  43. @classmethod
  44. def get_packed_tensor_names(cls) -> List[str]:
  45. return ["qzeros"]
  46. @classmethod
  47. def get_transposed_tensor_names(cls) -> List[str]:
  48. return ["qweight", "qzeros", "scales"]
  49. def get_row_tp_tensor_names(self) -> List[str]:
  50. if self.desc_act and self.group_size != -1:
  51. return ["qweight", "g_idx"]
  52. if self.group_size == -1:
  53. return ["qweight"]
  54. return ["qweight", "qzeros", "scales"]
  55. def get_column_tp_tensor_names(self) -> List[str]:
  56. return ["qweight", "qzeros", "scales", "bias"]
  57. def get_ignore_tensor_names(self) -> List[str]:
  58. if self.desc_act and self.group_size != -1:
  59. return []
  60. return ["g_idx"]