base_config.py

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type

import torch
from torch import nn


class QuantizeMethodBase(ABC):
    """Base class for different quantization methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        """Create weights for a layer.

        The weights will be set as attributes of the layer."""
        raise NotImplementedError

    @abstractmethod
    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

    # Not required functions
    def embedding(self, layer: torch.nn.Module, *args,
                  **kwargs) -> torch.Tensor:
        """Gather embeddings in the layer based on indices in the input tensor.

        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Process the weights after loading.

        This can be used, for example, to transpose weights for computation.
        """
        return
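
# A minimal sketch of a concrete subclass, assuming a plain (unquantized)
# linear layer. `UnquantizedLinearSketch` and the positional weight_args
# layout it expects are illustrative assumptions, not part of this module.
class UnquantizedLinearSketch(QuantizeMethodBase):

    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        # Assumed weight_args layout: (input_size, output_size, dtype).
        input_size, output_size, dtype = weight_args
        weight = nn.Parameter(
            torch.empty(output_size, input_size, dtype=dtype),
            requires_grad=False)
        # Per the contract above, the weights are set as attributes of
        # the layer.
        layer.register_parameter("weight", weight)

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # A plain matmul; a real quantized method would dequantize or
        # dispatch to a custom kernel here.
        return nn.functional.linear(x, layer.weight, bias)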
def method_has_implemented_embedding(
        method_class: Type[QuantizeMethodBase]) -> bool:
    """
    Not all quant methods have embedding implemented, so we need to check that
    it exists for our given method. We check this by making sure the function
    has been changed from the base implementation.
    """
    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
                                            None)
    class_embedding = inspect.getattr_static(method_class, "embedding", None)
    return (class_embedding is not None
            and class_embedding is not base_embedding)
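
# Illustrative check, assuming the UnquantizedLinearSketch example above
# (which does not override `embedding`):
#
#   method_has_implemented_embedding(UnquantizedLinearSketch)  # -> False
#
# A subclass that defines its own `embedding` yields True:
# inspect.getattr_static resolves the attribute without invoking
# descriptors, so comparing it against the base implementation is
# reliable even for plain functions defined on the class.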
class QuantizationConfig(ABC):
    """Base class for quantization configs."""

    @abstractmethod
    def get_name(self) -> str:
        """Name of the quantization method."""
        raise NotImplementedError

    @abstractmethod
    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
        This requirement is due to the custom CUDA kernels used by the
        quantization method.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """
        Detects if this quantization method can support a given checkpoint
        format by overriding the user-specified quantization method.
        This method should only be overridden by subclasses in exceptional
        circumstances.
        """
        return None

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Get a value from the model's quantization config."""
        for key in keys:
            if key in config:
                return config[key]
        raise ValueError(f"Cannot find any of {keys} in the model's "
                         "quantization config.")

    @staticmethod
    def get_from_keys_or(config: Dict[str, Any], keys: List[str],
                         default: Any) -> Any:
        """Get an optional value from the model's quantization config."""
        try:
            return QuantizationConfig.get_from_keys(config, keys)
        except ValueError:
            return default

    @abstractmethod
    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> QuantizeMethodBase:
        """Get the quantize method to use for the quantized layer."""
        raise NotImplementedError

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError
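
# A minimal sketch of a concrete config, to show how the abstract hooks
# fit together. The name "sketch_quant", the config keys "bits"/"w_bit",
# and the filename below are illustrative assumptions, not identifiers
# from a real quantization method.
class SketchQuantConfig(QuantizationConfig):

    def __init__(self, weight_bits: int) -> None:
        self.weight_bits = weight_bits

    def get_name(self) -> str:
        return "sketch_quant"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # Volta or newer (assumed kernel requirement).

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SketchQuantConfig":
        # get_from_keys lets one config class accept several spellings
        # of the same field across checkpoint formats.
        weight_bits = cls.get_from_keys(config, ["bits", "w_bit"])
        return cls(weight_bits)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> QuantizeMethodBase:
        # Hand back the example method sketched earlier in this file.
        return UnquantizedLinearSketch()

    def get_scaled_act_names(self) -> List[str]:
        return []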