1
0

base_config.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from abc import ABC, abstractmethod
  2. from typing import Any, Dict, List
  3. import torch
  4. from aphrodite.modeling.layers.linear import LinearMethodBase
  5. class QuantizationConfig(ABC):
  6. """Base class for quantization configs."""
  7. @abstractmethod
  8. def get_name(self) -> str:
  9. """Name of the quantization method."""
  10. raise NotImplementedError
  11. @abstractmethod
  12. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  13. """List of supported activation dtypes."""
  14. raise NotImplementedError
  15. @abstractmethod
  16. def get_min_capability(self) -> int:
  17. """Minimum GPU capability to support the quantization method.
  18. E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
  19. This requirement is due to the custom CUDA kernels used by the
  20. quantization method.
  21. """
  22. raise NotImplementedError
  23. @staticmethod
  24. @abstractmethod
  25. def get_config_filenames() -> List[str]:
  26. """List of filenames to search for in the model directory."""
  27. raise NotImplementedError
  28. @classmethod
  29. @abstractmethod
  30. def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
  31. """Create a config class from the model's quantization config."""
  32. raise NotImplementedError
  33. @staticmethod
  34. def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
  35. """Get a value from the model's quantization config."""
  36. for key in keys:
  37. if key in config:
  38. return config[key]
  39. raise ValueError(f"Cannot find any of {keys} in the model's "
  40. "quantization config.")
  41. @abstractmethod
  42. def get_linear_method(self) -> LinearMethodBase:
  43. """Get the linear method to use for the quantized linear layer."""
  44. raise NotImplementedError
  45. @abstractmethod
  46. def get_scaled_act_names(self) -> List[str]:
  47. """Returns the activation function names that should be post-scaled.
  48. For now, this is only used by AWQ.
  49. """
  50. raise NotImplementedError