rocm.py 354 B

123456789101112131415
  1. from functools import lru_cache
  2. from typing import Tuple
  3. import torch
  4. from .interface import Platform, PlatformEnum
  5. class RocmPlatform(Platform):
  6. _enum = PlatformEnum.ROCM
  7. @staticmethod
  8. @lru_cache(maxsize=8)
  9. def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
  10. return torch.cuda.get_device_capability(device_id)