interface.py 1.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import enum
  2. from typing import Tuple
  3. import torch
  4. class PlatformEnum(enum.Enum):
  5. CUDA = enum.auto()
  6. ROCM = enum.auto()
  7. TPU = enum.auto()
  8. UNSPECIFIED = enum.auto()
  9. class Platform:
  10. _enum: PlatformEnum
  11. def is_cuda(self) -> bool:
  12. return self._enum == PlatformEnum.CUDA
  13. def is_rocm(self) -> bool:
  14. return self._enum == PlatformEnum.ROCM
  15. def is_tpu(self) -> bool:
  16. return self._enum == PlatformEnum.TPU
  17. @staticmethod
  18. def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
  19. raise NotImplementedError
  20. @staticmethod
  21. def inference_mode():
  22. """A device-specific wrapper of `torch.inference_mode`.
  23. This wrapper is recommended because some hardware backends such as TPU
  24. do not support `torch.inference_mode`. In such a case, they will fall
  25. back to `torch.no_grad` by overriding this method.
  26. """
  27. return torch.inference_mode(mode=True)
  28. class UnspecifiedPlatform(Platform):
  29. _enum = PlatformEnum.UNSPECIFIED