123456789101112131415161718192021222324252627282930313233343536 |
- import pytest
- import torch
- from aphrodite.scalar_type import scalar_types
- @pytest.mark.parametrize("type_tuple", (
- (-8, 7, scalar_types.int4),
- (0, 15, scalar_types.uint4),
- (-8, 7, scalar_types.uint4b8),
- (-128, 127, scalar_types.uint8b128),
- (-28., 28., scalar_types.float6_e3m2f),
- (torch.int8, scalar_types.int8),
- (torch.uint8, scalar_types.uint8),
- (torch.float8_e5m2, scalar_types.float8_e5m2),
- (torch.float8_e4m3fn, scalar_types.float8_e4m3fn),
- (torch.bfloat16, scalar_types.float16_e8m7),
- (torch.float16, scalar_types.float16_e5m10),
- ),
- ids=lambda x: str(x))
- def test_scalar_type_min_max(type_tuple):
- print(type_tuple)
- if len(type_tuple) == 3:
- min, max, t = type_tuple
- else:
- torch_type, t = type_tuple
- if torch_type.is_floating_point:
- min = torch.finfo(torch_type).min
- max = torch.finfo(torch_type).max
- else:
- min = torch.iinfo(torch_type).min
- max = torch.iinfo(torch_type).max
- print(t, min, max, t.min(), t.max())
- assert min == t.min()
- assert max == t.max()
|