# interface.py
import enum
from typing import Optional, Tuple

import torch
  4. class PlatformEnum(enum.Enum):
  5. CUDA = enum.auto()
  6. ROCM = enum.auto()
  7. TPU = enum.auto()
  8. CPU = enum.auto()
  9. UNSPECIFIED = enum.auto()
  10. class Platform:
  11. _enum: PlatformEnum
  12. def is_cuda(self) -> bool:
  13. return self._enum == PlatformEnum.CUDA
  14. def is_rocm(self) -> bool:
  15. return self._enum == PlatformEnum.ROCM
  16. def is_tpu(self) -> bool:
  17. return self._enum == PlatformEnum.TPU
  18. def is_cpu(self) -> bool:
  19. return self._enum == PlatformEnum.CPU
  20. @staticmethod
  21. def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]:
  22. return None
  23. @staticmethod
  24. def get_device_name(device_id: int = 0) -> str:
  25. raise NotImplementedError
  26. @staticmethod
  27. def inference_mode():
  28. """A device-specific wrapper of `torch.inference_mode`.
  29. This wrapper is recommended because some hardware backends such as TPU
  30. do not support `torch.inference_mode`. In such a case, they will fall
  31. back to `torch.no_grad` by overriding this method.
  32. """
  33. return torch.inference_mode(mode=True)
  34. class UnspecifiedPlatform(Platform):
  35. _enum = PlatformEnum.UNSPECIFIED