scalar_type.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435
  1. from ._core_ext import NanRepr, ScalarType
  2. # naming generally follows: https://github.com/jax-ml/ml_dtypes
  3. # for floating point types (leading f) the scheme is:
  4. # `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
  5. # flags:
  6. # - no-flags: means it follows IEEE 754 conventions
  7. # - f: means finite values only (no infinities)
  8. # - n: means NaNs are supported (non-standard encoding)
  9. # for integer types the scheme is:
  10. # `[u]int<size_bits>[b<bias>]`
  11. # - if the bias is not present, it is zero
  12. class scalar_types:
  13. int4 = ScalarType.int_(4, None)
  14. uint4 = ScalarType.uint(4, None)
  15. int8 = ScalarType.int_(8, None)
  16. uint8 = ScalarType.uint(8, None)
  17. float8_e4m3fn = ScalarType.float_(4, 3, True,
  18. NanRepr.EXTD_RANGE_MAX_MIN.value)
  19. float8_e5m2 = ScalarType.float_IEEE754(5, 2)
  20. float16_e8m7 = ScalarType.float_IEEE754(8, 7)
  21. float16_e5m10 = ScalarType.float_IEEE754(5, 10)
  22. # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
  23. float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
  24. # "gptq" types
  25. uint4b8 = ScalarType.uint(4, 8)
  26. uint8b128 = ScalarType.uint(8, 128)
  27. # colloquial names
  28. bfloat16 = float16_e8m7
  29. float16 = float16_e5m10