# allclose_default.py: default atol/rtol tolerances, keyed by dtype, for allclose-style comparisons.
  1. import torch
  2. # Reference default values of atol and rtol are from
  3. # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
  4. default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
  5. default_rtol = {
  6. torch.float16: 1e-3,
  7. torch.bfloat16: 1.6e-2,
  8. torch.float: 1.3e-6
  9. }
  10. def get_default_atol(output) -> float:
  11. return default_atol[output.dtype]
  12. def get_default_rtol(output) -> float:
  13. return default_rtol[output.dtype]