_core_ext.py

import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

import torch
from loguru import logger

core_C_available = importlib.util.find_spec(
    '._core_C', 'aphrodite') is not None


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


if TYPE_CHECKING or not core_C_available:
    # On platforms where we cannot use/build the C++ core extension (namely
    # neuron and tpu), we define a mock ScalarType class here that partially
    # mimics the C++ ScalarType class.
    #
    # We also use this to provide type signatures to the Python LSP for the
    # methods in the C++ ScalarType class, so these type signatures should be
    # kept in sync with kernels/core/scalar_type.hpp.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
  24. """
  25. ScalarType can represent a wide range of floating point and integer
  26. types, in particular it can be used to represent sub-byte data types
  27. (something that torch.dtype currently does not support). It is also
  28. capable of representing types with a bias, i.e.:
  29. `stored_value = value + bias`,
  30. this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
  31. of 8). The implementation for this class can be found in
  32. kernels/core/scalar_type.hpp, these type signatures should be kept in
  33. sync with that file.
  34. """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this is an integer type)
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number of bits representing an integer (excluding the sign
        bit) if this is an integer type.
        """

        bias: int
        """
        bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0). For example, if we store
        the type as an unsigned integer with a bias of 128, then the value 0
        will be stored as 128, -1 will be stored as 127, and 1 will be stored
        as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: whether the type only supports finite values; use
        `has_infs()` to query inf support instead.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represented in this scalar type (returns a NanRepr
        value; not applicable for integer types).
        """

        @property
        def size_bits(self):
            return self.exponent + self.mantissa + int(self.signed)
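
        # Example (illustrative, not in the original file): an IEEE 754
        # half-precision float has exponent=5, mantissa=10 and a sign bit,
        # so size_bits == 5 + 10 + 1 == 16; a 4-bit unsigned integer has
        # exponent=0, mantissa=4 and no sign bit, so size_bits == 4.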

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`;
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            ...

        def is_floating_point(self) -> bool:
            "If the type is a floating point type"
            return self.exponent != 0

        def is_integer(self) -> bool:
            "If the type is an integer type"
            return self.exponent == 0

        def has_bias(self) -> bool:
            "If the type has a non-zero bias"
            return self.bias != 0

        def has_infs(self) -> bool:
            "If the type is floating point and supports infinity"
            return not self._finite_values_only

        def has_nans(self) -> bool:
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

        # __len__ needs to be defined (and has to throw TypeError) for
        # pytorch's opcheck to work.
        def __len__(self) -> int:
            raise TypeError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            return cls(0, size_bits - 1, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create an unsigned integer scalar type."""
            return cls(0, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
                   nan_repr: int) -> 'ScalarType':
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)

elif core_C_available:
    try:
        import aphrodite._core_C  # noqa: F401
    except ImportError as e:
        logger.warning(f"Failed to import from aphrodite._core_C with {e}")

    ScalarType = torch.classes._core_C.ScalarType
  148. if (hasattr(torch, "_library")
  149. and hasattr(torch._library, "register_fake_class")):
  150. # Needed for dynamo support of ScalarType.
  151. @torch._library.register_fake_class("_core_C::ScalarType")
  152. class FakeScalarType:
  153. def __init__(self, scalar_type):
  154. self.ScalarType = scalar_type
  155. def bias_getter(self) -> int:
  156. return self.ScalarType.bias
  157. def exponent_getter(self) -> int:
  158. return self.ScalarType.exponent
  159. def mantissa_getter(self) -> int:
  160. return self.ScalarType.mantissa
  161. def signed_getter(self) -> bool:
  162. return self.ScalarType.signed
  163. def size_bits_getter(self) -> int:
  164. return self.ScalarType.size_bits
  165. @property
  166. def size_bits(self) -> int:
  167. return self.ScalarType.size_bits
  168. def min(self) -> Union[int, float]:
  169. return self.ScalarType.min()
  170. def max(self) -> Union[int, float]:
  171. return self.ScalarType.max()
  172. def is_signed(self) -> bool:
  173. return self.ScalarType.is_signed()
  174. def is_floating_point(self) -> bool:
  175. return self.ScalarType.is_floating_point()
  176. def is_integer(self) -> bool:
  177. return self.ScalarType.is_integer()
  178. def has_bias(self) -> bool:
  179. return self.ScalarType.has_bias()
  180. def has_infs(self) -> bool:
  181. return self.ScalarType.has_infs()
  182. def has_nans(self) -> bool:
  183. return self.ScalarType.has_nans()
  184. def is_ieee_754(self) -> bool:
  185. return self.ScalarType.is_ieee_754()
  186. def __str__(self) -> str:
  187. return self.ScalarType.__str__()
  188. def __repr__(self) -> str:
  189. return self.ScalarType.__repr__()
  190. def __len__(self) -> int:
  191. return self.ScalarType.__len__()
  192. def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
  193. return torch.classes._core_C.ScalarType.__obj_flatten__(
  194. self.ScalarType)
  195. @classmethod
  196. def __obj_unflatten__(
  197. cls, flat_type: Tuple[Tuple[str, Any],
  198. ...]) -> 'ScalarType':
  199. return cls(
  200. torch.classes._core_C.ScalarType.__obj_unflatten__(
  201. flat_type))
  202. @classmethod
  203. def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
  204. return ScalarType.int_(size_bits, bias)
  205. @classmethod
  206. def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
  207. return ScalarType.uint(size_bits, bias)
  208. @classmethod
  209. def float_IEEE754(cls, exponent: int,
  210. mantissa: int) -> 'ScalarType':
  211. return ScalarType.float_IEEE754(exponent, mantissa)
  212. @classmethod
  213. def float_(cls, exponent: int, mantissa: int,
  214. finite_values_only: bool,
  215. nan_repr: int) -> 'ScalarType':
  216. return ScalarType.float_(exponent, mantissa,
  217. finite_values_only, nan_repr)
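

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): exercise the
# convenience constructors and predicates defined above. Behaviour depends on
# whether the C++ extension or the Python mock is active; the expected values
# in the comments assume the mock ScalarType.
# ---------------------------------------------------------------------------
if __name__ == "__main__":  # pragma: no cover
    gptq_uint4 = ScalarType.uint(4, 8)       # 4-bit unsigned int, bias 8
    fp16 = ScalarType.float_IEEE754(5, 10)   # 5 exponent / 10 mantissa bits

    print(gptq_uint4.has_bias())      # True (bias of 8)
    print(gptq_uint4.is_integer())    # True
    print(fp16.is_floating_point())   # True
    print(fp16.is_ieee_754())         # True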