base.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from typing import Any, Dict, List
  2. import torch
  3. class QuantizationConfig:
  4. @classmethod
  5. def get_name(cls) -> str:
  6. """Name of the quantization method."""
  7. raise NotImplementedError
  8. @classmethod
  9. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  10. """List of supported activation dtypes."""
  11. raise NotImplementedError
  12. @classmethod
  13. def get_min_capability(cls) -> int:
  14. """Minimum GPU capability to support the quantization method.
  15. E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
  16. This requirement is due to the custom CUDA kernels used by the
  17. quantization method.
  18. """
  19. raise NotImplementedError
  20. @classmethod
  21. def get_config_filenames(cls) -> List[str]:
  22. """List of filenames to search for in the model directory."""
  23. raise NotImplementedError
  24. @classmethod
  25. def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
  26. """Create a config class from the model's quantization config."""
  27. raise NotImplementedError
  28. @staticmethod
  29. def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
  30. """Get a value from the model's quantization config."""
  31. for key in keys:
  32. if key in config:
  33. return config[key]
  34. raise ValueError(f"Cannot find any of {keys} in the model's "
  35. "quantization config.")
  36. @classmethod
  37. def get_packed_tensor_names(cls) -> List[str]:
  38. raise NotImplementedError
  39. @classmethod
  40. def is_packed(cls, tensor_name: str) -> bool:
  41. """Returns True if a tensor is packed.
  42. A tensor is considered packed if each element in the tensor is a
  43. packed representation of multiple elements in the original tensor.
  44. For example, an INT32 element in the tensor may represent 8 INT4
  45. elements in the original tensor.
  46. """
  47. return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
  48. @classmethod
  49. def get_transposed_tensor_names(cls) -> List[str]:
  50. raise NotImplementedError
  51. @classmethod
  52. def is_transposed(cls, tensor_name: str) -> bool:
  53. """Returns True if a tensor is transposed relative to nn.Linear.weight.
  54. """
  55. return any(tag in tensor_name
  56. for tag in cls.get_transposed_tensor_names())
  57. @classmethod
  58. def get_row_tp_tensor_names(self) -> List[str]:
  59. raise NotImplementedError
  60. def get_column_tp_tensor_names(self) -> List[str]:
  61. raise NotImplementedError
  62. def get_ignore_tensor_names(self) -> List[str]:
  63. return []