import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import torch
from loguru import logger

core_C_available = importlib.util.find_spec('._core_C',
                                            'aphrodite') is not None

# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
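
# Illustrative note (an assumption based on common formats, not stated in
# this file): IEEE-style types such as float16 use NanRepr.IEEE_754, while
# "finite"-flavored fp8 formats like E4M3FN reclaim most exponent-all-1s
# encodings for normal values and keep only the exponent-all-1s /
# mantissa-all-1s pattern as NaN, matching NanRepr.EXTD_RANGE_MAX_MIN.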


if TYPE_CHECKING or not core_C_available:
    # On platforms where we cannot use/build the C++ core extension (namely
    # Neuron and TPU), we define a mock ScalarType class here that partially
    # mimics the C++ ScalarType class.
    #
    # We also use this to provide type signatures to the Python LSP for the
    # methods in the C++ ScalarType class, so these type signatures should be
    # kept in sync with csrc/core/scalar_type.hpp
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
        ScalarType can represent a wide range of floating point and integer
        types; in particular, it can be used to represent sub-byte data types
        (something that torch.dtype currently does not support). It is also
        capable of representing types with a bias, i.e.
        `stored_value = value + bias`, which is useful for quantized types
        (e.g. standard GPTQ 4-bit uses a bias of 8). The implementation for
        this class can be found in csrc/core/scalar_type.hpp; these type
        signatures should be kept in sync with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this is an integer type).
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number of bits representing an integer excluding the sign bit
        if this is an integer type.
        """

        bias: int
        """
        Bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0). For example, if we store
        the type as an unsigned integer with a bias of 128, then the value 0
        will be stored as 128, -1 as 127, and 1 as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: use `has_infs()` to check whether the type supports
        infinity.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represented in this scalar type (a `NanRepr` value;
        not applicable for integer types).
        """

        @property
        def size_bits(self):
            return self.exponent + self.mantissa + int(self.signed)
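        # e.g. an fp8 E4M3 type has 4 exponent bits, 3 mantissa bits, and
        # one sign bit, so size_bits == 8.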

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type.
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`;
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            return self.signed

        def is_floating_point(self):
            "If the type is a floating point type"
            return self.exponent != 0

        def is_integer(self):
            "If the type is an integer type"
            return self.exponent == 0

        def has_bias(self):
            "If the type has a non-zero bias"
            return self.bias != 0

        def has_infs(self):
            "If the type is floating point and supports infinity"
            return not self._finite_values_only

        def has_nans(self):
            "If the type supports NaN values"
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            # Integer types have exponent == 0; the sign bit is accounted
            # for by `signed=True`, so the mantissa holds size_bits - 1 bits.
            return cls(0, size_bits - 1, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create an unsigned integer scalar type."""
            return cls(0, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int,
                   finite_values_only: bool, nan_repr: int):
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)


elif core_C_available:
    try:
        import aphrodite._core_C  # noqa: F401
    except ImportError as e:
        logger.warning(f"Failed to import from aphrodite._core_C with {e}")

    ScalarType = torch.classes._core_C.ScalarType
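

# A minimal smoke test (an illustrative sketch, not part of the module's
# public API): it assumes the mock Python ScalarType defined above is in
# use and exercises the convenience constructors. The names `gptq_uint4b8`
# and `fp8_e4m3` are hypothetical, chosen for this example.
if __name__ == "__main__":
    # Standard GPTQ 4-bit: an unsigned 4-bit integer with a bias of 8, so
    # stored values 0..15 decode to -8..7 (value = stored_value - bias).
    gptq_uint4b8 = ScalarType.uint(4, 8)
    assert gptq_uint4b8.size_bits == 4
    assert gptq_uint4b8.is_integer() and gptq_uint4b8.has_bias()
    assert 0 - gptq_uint4b8.bias == -8  # smallest decodable value

    # An fp8 E4M3 type with IEEE 754 semantics:
    # 4 exponent bits + 3 mantissa bits + sign bit = 8 bits.
    fp8_e4m3 = ScalarType.float_IEEE754(4, 3)
    assert fp8_e4m3.size_bits == 8
    assert fp8_e4m3.is_floating_point() and fp8_e4m3.is_ieee_754()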
|