__init__.py 696 B

1234567891011121314151617181920212223242526
  1. import torch
  2. from .interface import Platform, PlatformEnum, UnspecifiedPlatform
  3. current_platform: Platform
  4. try:
  5. import libtpu
  6. except ImportError:
  7. libtpu = None
  8. if libtpu is not None:
  9. # people might install pytorch built with cuda but run on tpu
  10. # so we need to check tpu first
  11. from .tpu import TpuPlatform
  12. current_platform = TpuPlatform()
  13. elif torch.version.cuda is not None:
  14. from .cuda import CudaPlatform
  15. current_platform = CudaPlatform()
  16. elif torch.version.hip is not None:
  17. from .rocm import RocmPlatform
  18. current_platform = RocmPlatform()
  19. else:
  20. current_platform = UnspecifiedPlatform()
  21. __all__ = ['Platform', 'PlatformEnum', 'current_platform']