from typing import Any, Dict, List, Optional

import torch
  3. class QuantizationConfig:
  4. @classmethod
  5. def get_name(cls) -> str:
  6. """Name of the quantization method."""
  7. raise NotImplementedError
  8. @classmethod
  9. def get_supported_act_dtypes(cls) -> List[torch.dtype]:
  10. """List of supported activation dtypes."""
  11. raise NotImplementedError
  12. @classmethod
  13. def get_min_capability(cls) -> int:
  14. """Minimum GPU capability to support the quantization method.
  15. E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
  16. This requirement is due to the custom CUDA kernels used by the
  17. quantization method.
  18. """
  19. raise NotImplementedError
  20. @classmethod
  21. def get_config_filenames(cls) -> List[str]:
  22. """List of filenames to search for in the model directory."""
  23. raise NotImplementedError
  24. @classmethod
  25. def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
  26. """Create a config class from the model's quantization config."""
  27. raise NotImplementedError
  28. @staticmethod
  29. def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
  30. """Get a value from the model's quantization config."""
  31. for key in keys:
  32. if key in config:
  33. return config[key]
  34. raise ValueError(f"Cannot find any of {keys} in the model's "
  35. "quantization config.")
  36. @classmethod
  37. def get_packed_tensors(cls) -> Dict[str, int]:
  38. """Returns a dictionary of packed tensor names and their pack dims."""
  39. raise NotImplementedError
  40. @classmethod
  41. def get_packed_dim(cls, tensor_name: str) -> Optional[int]:
  42. """Returns the pack dim of a tensor if it is packed.
  43. A tensor is considered packed if each element in the tensor is a
  44. packed representation of multiple elements in the original tensor.
  45. For example, an INT32 element in the tensor may represent 8 INT4
  46. elements in the original tensor.
  47. If the tensor is not packed, returns None.
  48. """
  49. packed_tensors = cls.get_packed_tensors()
  50. for packed_tensor_name, pack_dim in packed_tensors.items():
  51. if packed_tensor_name in tensor_name:
  52. return pack_dim
  53. return None
  54. @classmethod
  55. def get_transposed_tensor_names(cls) -> List[str]:
  56. raise NotImplementedError
  57. @classmethod
  58. def is_transposed(cls, tensor_name: str) -> bool:
  59. """Returns True if a tensor is transposed relative to nn.Linear.weight.
  60. """
  61. return any(tag in tensor_name
  62. for tag in cls.get_transposed_tensor_names())
  63. def get_col_parallel_tensor_names(self) -> List[str]:
  64. raise NotImplementedError
  65. def get_row_parallel_tensor_names(self) -> List[str]:
  66. raise NotImplementedError