1234567891011121314151617181920 |
- from functools import lru_cache
- from typing import Tuple
- import torch
- from .interface import Platform, PlatformEnum
class RocmPlatform(Platform):
    """Platform implementation tagged with ``PlatformEnum.ROCM``.

    Device queries are delegated to the ``torch.cuda`` API. Each query is
    wrapped in an ``lru_cache`` holding up to 8 entries, so repeated
    lookups with the same argument reuse the cached result.
    """

    _enum = PlatformEnum.ROCM

    @staticmethod
    @lru_cache(maxsize=8)
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Return the capability pair reported by ``torch.cuda`` for a device.

        Args:
            device_id: Index of the device to query. Defaults to 0.

        Returns:
            The value of ``torch.cuda.get_device_capability(device_id)``.
        """
        capability = torch.cuda.get_device_capability(device_id)
        return capability

    @staticmethod
    @lru_cache(maxsize=8)
    def get_device_name(device_id: int = 0) -> str:
        """Return the device name reported by ``torch.cuda`` for a device.

        Args:
            device_id: Index of the device to query. Defaults to 0.

        Returns:
            The value of ``torch.cuda.get_device_name(device_id)``.
        """
        device_name = torch.cuda.get_device_name(device_id)
        return device_name
|