__init__.py 920 B

123456789101112131415161718192021222324252627282930313233343536
  1. import torch
  2. from .interface import Platform, PlatformEnum, UnspecifiedPlatform
  3. current_platform: Platform
  4. try:
  5. import libtpu
  6. except ImportError:
  7. libtpu = None
  8. is_cpu = False
  9. try:
  10. from importlib.metadata import version
  11. is_cpu = "cpu" in version("aphrodite-engine")
  12. except Exception:
  13. pass
  14. if libtpu is not None:
  15. # people might install pytorch built with cuda but run on tpu
  16. # so we need to check tpu first
  17. from .tpu import TpuPlatform
  18. current_platform = TpuPlatform()
  19. elif torch.version.cuda is not None:
  20. from .cuda import CudaPlatform
  21. current_platform = CudaPlatform()
  22. elif torch.version.hip is not None:
  23. from .rocm import RocmPlatform
  24. current_platform = RocmPlatform()
  25. elif is_cpu:
  26. from .cpu import CpuPlatform
  27. current_platform = CpuPlatform()
  28. else:
  29. current_platform = UnspecifiedPlatform()
  30. __all__ = ['Platform', 'PlatformEnum', 'current_platform']