import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

import torch
from loguru import logger

core_C_available = importlib.util.find_spec('._core_C',
                                             'aphrodite') is not None


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
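
# Illustrative note (an assumption about usage, not taken from
# scalar_type.hpp): an fp8 E4M3fn-style format, which has no infinities and
# encodes NaN as "exponent all 1s, mantissa all 1s", would map to
# NanRepr.EXTD_RANGE_MAX_MIN, e.g. via the mock constructor defined below:
#   ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN.value)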


if TYPE_CHECKING or not core_C_available:
    # On platforms where we cannot use/build the C++ core extension (namely
    # neuron and tpu), we define a mock ScalarType class here that partially
    # mimics the C++ ScalarType class.
    #
    # We also use this to provide type signatures to the Python LSP for the
    # methods in the C++ ScalarType class. So these type signatures should be
    # kept in sync with kernels/core/scalar_type.hpp
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
        ScalarType can represent a wide range of floating point and integer
        types, in particular it can be used to represent sub-byte data types
        (something that torch.dtype currently does not support). It is also
        capable of representing types with a bias, i.e.:
          `stored_value = value + bias`,
        which is useful for quantized types (e.g. standard GPTQ 4bit uses a
        bias of 8). The implementation for this class can be found in
        kernels/core/scalar_type.hpp; these type signatures should be kept in
        sync with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this is an integer type)
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number of bits representing an integer (excluding the sign
        bit) if this is an integer type.
        """

        bias: int
        """
        Bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0). For example, if we store
        the type as an unsigned integer with a bias of 128, then the value 0
        is stored as 128, -1 is stored as 127, and 1 is stored as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: if infs are supported, use `has_infs()` instead.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represented in this scalar type; stores a NanRepr value.
        (not applicable for integer types)
        """

        @property
        def size_bits(self):
            return self.exponent + self.mantissa + int(self.signed)

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`;
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            ...

        def is_floating_point(self) -> bool:
            "If the type is a floating point type"
            return self.exponent != 0

        def is_integer(self) -> bool:
            "If the type is an integer type"
            return self.exponent == 0

        def has_bias(self) -> bool:
            "If the type has a non-zero bias"
            return self.bias != 0

        def has_infs(self) -> bool:
            "If the type is floating point and supports infinity"
            return not self._finite_values_only

        def has_nans(self) -> bool:
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

        # __len__ needs to be defined (and has to throw TypeError) for
        # pytorch's opcheck to work.
        def __len__(self) -> int:
            raise TypeError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            # Integer types use zero exponent bits; the mantissa holds the
            # value bits, excluding the sign bit.
            return cls(0, size_bits - 1, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create an unsigned integer scalar type."""
            return cls(0, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
                   nan_repr: int) -> 'ScalarType':
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)
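
    # Illustrative sketch (not part of the C++ API surface): with the
    # constructors above, a GPTQ-style 4-bit unsigned type with a bias of 8
    # would be ScalarType.uint(4, 8), giving size_bits == 4 and
    # has_bias() == True, where a stored value of 0 decodes to 0 - 8 = -8;
    # a standard float32 would be ScalarType.float_IEEE754(8, 23), giving
    # size_bits == 8 + 23 + 1 == 32 and is_ieee_754() == True.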

elif core_C_available:
    try:
        import aphrodite._core_C  # noqa: F401
    except ImportError as e:
        logger.warning(f"Failed to import from aphrodite._core_C with {e}")

    ScalarType = torch.classes._core_C.ScalarType

    if (hasattr(torch, "_library")
            and hasattr(torch._library, "register_fake_class")):
        # Needed for dynamo support of ScalarType.
        @torch._library.register_fake_class("_core_C::ScalarType")
        class FakeScalarType:

            def __init__(self, scalar_type):
                self.ScalarType = scalar_type

            def bias_getter(self) -> int:
                return self.ScalarType.bias

            def exponent_getter(self) -> int:
                return self.ScalarType.exponent

            def mantissa_getter(self) -> int:
                return self.ScalarType.mantissa

            def signed_getter(self) -> bool:
                return self.ScalarType.signed

            def size_bits_getter(self) -> int:
                return self.ScalarType.size_bits

            @property
            def size_bits(self) -> int:
                return self.ScalarType.size_bits

            def min(self) -> Union[int, float]:
                return self.ScalarType.min()

            def max(self) -> Union[int, float]:
                return self.ScalarType.max()

            def is_signed(self) -> bool:
                return self.ScalarType.is_signed()

            def is_floating_point(self) -> bool:
                return self.ScalarType.is_floating_point()

            def is_integer(self) -> bool:
                return self.ScalarType.is_integer()

            def has_bias(self) -> bool:
                return self.ScalarType.has_bias()

            def has_infs(self) -> bool:
                return self.ScalarType.has_infs()

            def has_nans(self) -> bool:
                return self.ScalarType.has_nans()

            def is_ieee_754(self) -> bool:
                return self.ScalarType.is_ieee_754()

            def __str__(self) -> str:
                return self.ScalarType.__str__()

            def __repr__(self) -> str:
                return self.ScalarType.__repr__()

            def __len__(self) -> int:
                return self.ScalarType.__len__()

            def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
                return torch.classes._core_C.ScalarType.__obj_flatten__(
                    self.ScalarType)

            @classmethod
            def __obj_unflatten__(
                    cls, flat_type: Tuple[Tuple[str, Any],
                                          ...]) -> 'ScalarType':
                return cls(
                    torch.classes._core_C.ScalarType.__obj_unflatten__(
                        flat_type))

            @classmethod
            def int_(cls, size_bits: int,
                     bias: Optional[int]) -> 'ScalarType':
                return ScalarType.int_(size_bits, bias)

            @classmethod
            def uint(cls, size_bits: int,
                     bias: Optional[int]) -> 'ScalarType':
                return ScalarType.uint(size_bits, bias)

            @classmethod
            def float_IEEE754(cls, exponent: int,
                              mantissa: int) -> 'ScalarType':
                return ScalarType.float_IEEE754(exponent, mantissa)

            @classmethod
            def float_(cls, exponent: int, mantissa: int,
                       finite_values_only: bool,
                       nan_repr: int) -> 'ScalarType':
                return ScalarType.float_(exponent, mantissa,
                                         finite_values_only, nan_repr)
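
# Example (hypothetical names, for illustration only) of how callers might
# build concrete scalar types from this module:
#   uint4b8 = ScalarType.uint(4, 8)                  # GPTQ-style 4-bit, bias 8
#   float16_e5m10 = ScalarType.float_IEEE754(5, 10)  # standard IEEE fp16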