rocm.py 504 B

1234567891011121314151617181920
  1. from functools import lru_cache
  2. from typing import Tuple
  3. import torch
  4. from .interface import Platform, PlatformEnum
  5. class RocmPlatform(Platform):
  6. _enum = PlatformEnum.ROCM
  7. @staticmethod
  8. @lru_cache(maxsize=8)
  9. def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
  10. return torch.cuda.get_device_capability(device_id)
  11. @staticmethod
  12. @lru_cache(maxsize=8)
  13. def get_device_name(device_id: int = 0) -> str:
  14. return torch.cuda.get_device_name(device_id)