__init__.py 438 B

123456789101112131415161718
  1. from typing import Optional
  2. import torch
  3. from .interface import Platform, PlatformEnum
  4. current_platform: Optional[Platform]
  5. if torch.version.cuda is not None:
  6. from .cuda import CudaPlatform
  7. current_platform = CudaPlatform()
  8. elif torch.version.hip is not None:
  9. from .rocm import RocmPlatform
  10. current_platform = RocmPlatform()
  11. else:
  12. current_platform = None
  13. __all__ = ['Platform', 'PlatformEnum', 'current_platform']