# interface.py
# Standard library.
import enum
from typing import Tuple

# Third-party.
import torch
  4. class PlatformEnum(enum.Enum):
  5. CUDA = enum.auto()
  6. ROCM = enum.auto()
  7. TPU = enum.auto()
  8. UNSPECIFIED = enum.auto()
  9. class Platform:
  10. _enum: PlatformEnum
  11. def is_cuda(self) -> bool:
  12. return self._enum == PlatformEnum.CUDA
  13. def is_rocm(self) -> bool:
  14. return self._enum == PlatformEnum.ROCM
  15. def is_tpu(self) -> bool:
  16. return self._enum == PlatformEnum.TPU
  17. @staticmethod
  18. def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
  19. raise NotImplementedError
  20. @staticmethod
  21. def get_device_name(device_id: int = 0) -> str:
  22. raise NotImplementedError
  23. @staticmethod
  24. def inference_mode():
  25. """A device-specific wrapper of `torch.inference_mode`.
  26. This wrapper is recommended because some hardware backends such as TPU
  27. do not support `torch.inference_mode`. In such a case, they will fall
  28. back to `torch.no_grad` by overriding this method.
  29. """
  30. return torch.inference_mode(mode=True)
class UnspecifiedPlatform(Platform):
    # Platform tagged with PlatformEnum.UNSPECIFIED — presumably the
    # fallback when no concrete backend is detected; confirm against callers.
    _enum = PlatformEnum.UNSPECIFIED