1234567891011121314151617181920212223242526 |
"""Select the platform backend this process runs on.

Importing this module sets ``current_platform`` to an instance of the first
matching backend, probed in this order: TPU (``libtpu`` can be imported),
CUDA (``torch.version.cuda`` is set), ROCm (``torch.version.hip`` is set).
If none match, ``current_platform`` is an ``UnspecifiedPlatform``.
"""
import torch

from .interface import Platform, PlatformEnum, UnspecifiedPlatform

__all__ = ["Platform", "PlatformEnum", "current_platform"]

# Declared up front so the name carries a single type regardless of which
# branch below performs the assignment.
current_platform: Platform

# libtpu is optional; when it cannot be imported the name is bound to None
# and serves as the "no TPU" marker for the check below.
try:
    import libtpu
except ImportError:
    libtpu = None

# Each backend module is imported only inside its own branch.
if libtpu is not None:
    # TPU must be probed before CUDA: a CUDA build of PyTorch may be
    # installed on a machine that actually runs on TPU.
    from .tpu import TpuPlatform

    current_platform = TpuPlatform()
elif torch.version.cuda is not None:
    from .cuda import CudaPlatform

    current_platform = CudaPlatform()
elif torch.version.hip is not None:
    from .rocm import RocmPlatform

    current_platform = RocmPlatform()
else:
    # No recognised accelerator stack was detected.
    current_platform = UnspecifiedPlatform()
|