base_config.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from abc import ABC, abstractmethod
  2. from typing import Any, Dict, List
  3. import torch
  4. from torch import nn
  5. class QuantizeMethodBase(ABC):
  6. """Base class for different quantized methods."""
  7. @abstractmethod
  8. def create_weights(self, layer: torch.nn.Module, *weight_args,
  9. **extra_weight_attrs):
  10. """Create weights for a layer.
  11. The weights will be set as attributes of the layer."""
  12. raise NotImplementedError
  13. @abstractmethod
  14. def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
  15. """Apply the weights in layer to the input tensor.
  16. Expects create_weights to have been called before on the layer."""
  17. raise NotImplementedError
  18. def process_weights_after_loading(self, layer: nn.Module) -> None:
  19. """Process the weight after loading.
  20. This can be used for example, to transpose weights for computation.
  21. """
  22. return
  23. class QuantizationConfig(ABC):
  24. """Base class for quantization configs."""
  25. @abstractmethod
  26. def get_name(self) -> str:
  27. """Name of the quantization method."""
  28. raise NotImplementedError
  29. @abstractmethod
  30. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  31. """List of supported activation dtypes."""
  32. raise NotImplementedError
  33. @abstractmethod
  34. def get_min_capability(self) -> int:
  35. """Minimum GPU capability to support the quantization method.
  36. E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
  37. This requirement is due to the custom CUDA kernels used by the
  38. quantization method.
  39. """
  40. raise NotImplementedError
  41. @staticmethod
  42. @abstractmethod
  43. def get_config_filenames() -> List[str]:
  44. """List of filenames to search for in the model directory."""
  45. raise NotImplementedError
  46. @classmethod
  47. @abstractmethod
  48. def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
  49. """Create a config class from the model's quantization config."""
  50. raise NotImplementedError
  51. @staticmethod
  52. def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
  53. """Get a value from the model's quantization config."""
  54. for key in keys:
  55. if key in config:
  56. return config[key]
  57. raise ValueError(f"Cannot find any of {keys} in the model's "
  58. "quantization config.")
  59. @abstractmethod
  60. def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
  61. """Get the quantize method to use for the quantized layer."""
  62. raise NotImplementedError
  63. @abstractmethod
  64. def get_scaled_act_names(self) -> List[str]:
  65. """Returns the activation function names that should be post-scaled.
  66. For now, this is only used by AWQ.
  67. """
  68. raise NotImplementedError