__init__.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536
  1. from typing import Type
  2. # ruff: noqa: E501
  3. from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
  4. from aphrodite.modeling.layers.quantization.aqlm import AQLMConfig
  5. from aphrodite.modeling.layers.quantization.awq import AWQConfig
  6. from aphrodite.modeling.layers.quantization.bitsandbytes import BitsandBytesConfig
  7. from aphrodite.modeling.layers.quantization.exl2 import Exl2Config
  8. from aphrodite.modeling.layers.quantization.gguf import GGUFConfig
  9. from aphrodite.modeling.layers.quantization.gptq import GPTQConfig
  10. from aphrodite.modeling.layers.quantization.quip import QuipConfig
  11. from aphrodite.modeling.layers.quantization.squeezellm import SqueezeLLMConfig
  12. from aphrodite.modeling.layers.quantization.marlin import MarlinConfig
  13. _QUANTIZATION_CONFIG_REGISTRY = {
  14. "aqlm": AQLMConfig,
  15. "awq": AWQConfig,
  16. "bnb": BitsandBytesConfig,
  17. "exl2": Exl2Config,
  18. "gguf": GGUFConfig,
  19. "gptq": GPTQConfig,
  20. "quip": QuipConfig,
  21. "squeezellm": SqueezeLLMConfig,
  22. "marlin": MarlinConfig,
  23. }
  24. def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
  25. if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
  26. raise ValueError(f"Invalid quantization method: {quantization}")
  27. return _QUANTIZATION_CONFIG_REGISTRY[quantization]
  28. __all__ = [
  29. "QuantizationConfig",
  30. "get_quantization_config",
  31. ]