# _core_ext.py

import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import torch
from loguru import logger

core_C_available = importlib.util.find_spec(
    '._core_C', 'aphrodite') is not None


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s

if TYPE_CHECKING or not core_C_available:
    # On platforms where we cannot use/build the C++ core extension (namely
    # neuron and tpu), we define a mock ScalarType class here that partially
    # mimics the C++ ScalarType class.
    #
    # We also use this to provide type signatures to the Python LSP for the
    # methods in the C++ ScalarType class, so these type signatures should be
    # kept in sync with csrc/core/scalar_type.hpp
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
        ScalarType can represent a wide range of floating point and integer
        types, in particular it can be used to represent sub-byte data types
        (something that torch.dtype currently does not support). It is also
        capable of representing types with a bias, i.e.
        `stored_value = value + bias`, which is useful for quantized types
        (e.g. standard GPTQ 4bit uses a bias of 8). The implementation for
        this class can be found in csrc/core/scalar_type.hpp; these type
        signatures should be kept in sync with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this is an integer type)
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number of bits representing an integer excluding the sign bit
        if this is an integer type.
        """

        bias: int
        """
        bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0). For example, if we store
        the type as an unsigned integer with a bias of 128, then the value 0
        will be stored as 128, -1 will be stored as 127, and 1 will be stored
        as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: True if only finite values are representable; use
        `has_infs()` instead of reading this directly.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represented in this scalar type; holds a NanRepr value.
        (not applicable for integer types)
        """

        @property
        def size_bits(self):
            return self.exponent + self.mantissa + int(self.signed)
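        # For example, IEEE float16 (5 exponent bits, 10 mantissa bits, plus
        # a sign bit) gives size_bits = 5 + 10 + 1 = 16, while an unsigned
        # 4-bit integer (exponent=0, mantissa=4, no sign bit) gives
        # 0 + 4 + 0 = 4.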

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`;
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            ...

        def is_floating_point(self):
            "If the type is a floating point type"
            return self.exponent != 0

        def is_integer(self):
            "If the type is an integer type"
            return self.exponent == 0

        def has_bias(self):
            "If the type has a non-zero bias"
            return self.bias != 0

        def has_infs(self):
            "If the type is floating point and supports infinity"
            return not self._finite_values_only

        def has_nans(self):
            "If the type supports NaN values"
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            return cls(0, size_bits - 1, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create an unsigned integer scalar type."""
            return cls(0, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int,
                   finite_values_only: bool, nan_repr: int) -> 'ScalarType':
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)
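
        # Rough usage sketch for the constructors above (the names on the
        # left are illustrative, not definitions made by this file):
        #   uint4b8 = ScalarType.uint(4, 8)            # GPTQ-style 4-bit
        #   float16 = ScalarType.float_IEEE754(5, 10)  # standard fp16
        #   fp8_e4m3fn_like = ScalarType.float_(
        #       4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN.value)
        #   uint4b8.size_bits       # -> 4
        #   uint4b8.is_integer()    # -> True
        #   float16.is_ieee_754()   # -> True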

elif core_C_available:
    try:
        import aphrodite._core_C  # noqa: F401
    except ImportError as e:
        logger.warning(f"Failed to import from aphrodite._core_C with {e}")

    ScalarType = torch.classes._core_C.ScalarType
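
# Either branch leaves the module exporting a `ScalarType` name with the same
# surface, so callers need not care which definition was picked. A minimal
# consumption sketch (assuming this module is importable as
# aphrodite._core_ext):
#
#   from aphrodite._core_ext import ScalarType
#   qtype = ScalarType.uint(4, 8)   # e.g. a GPTQ-style unsigned 4-bit type
#   qtype.is_integer()              # -> True
#   qtype.has_bias()                # -> True
#
# When the C++ extension is available, `ScalarType` is the torch custom class
# registered by `_core_C`, so instances can also be passed across the
# Python/C++ boundary to custom ops that declare a ScalarType parameter.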