# test_scalartype.py (1.1 KB)
  1. import pytest
  2. import torch
  3. from aphrodite.scalar_type import scalar_types
  4. @pytest.mark.parametrize("type_tuple", (
  5. (-8, 7, scalar_types.int4),
  6. (0, 15, scalar_types.uint4),
  7. (-8, 7, scalar_types.uint4b8),
  8. (-128, 127, scalar_types.uint8b128),
  9. (-28., 28., scalar_types.float6_e3m2f),
  10. (torch.int8, scalar_types.int8),
  11. (torch.uint8, scalar_types.uint8),
  12. (torch.float8_e5m2, scalar_types.float8_e5m2),
  13. (torch.float8_e4m3fn, scalar_types.float8_e4m3fn),
  14. (torch.bfloat16, scalar_types.float16_e8m7),
  15. (torch.float16, scalar_types.float16_e5m10),
  16. ),
  17. ids=lambda x: str(x))
  18. def test_scalar_type_min_max(type_tuple):
  19. print(type_tuple)
  20. if len(type_tuple) == 3:
  21. min, max, t = type_tuple
  22. else:
  23. torch_type, t = type_tuple
  24. if torch_type.is_floating_point:
  25. min = torch.finfo(torch_type).min
  26. max = torch.finfo(torch_type).max
  27. else:
  28. min = torch.iinfo(torch_type).min
  29. max = torch.iinfo(torch_type).max
  30. print(t, min, max, t.min(), t.max())
  31. assert min == t.min()
  32. assert max == t.max()