cuda.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
"""Code inside this file can safely assume the CUDA platform, e.g. importing
pynvml. However, it should not initialize a CUDA context.
"""
  4. import os
  5. from functools import lru_cache, wraps
  6. from typing import Tuple
  7. import pynvml
  8. from .interface import Platform, PlatformEnum
  9. def with_nvml_context(fn):
  10. @wraps(fn)
  11. def wrapper(*args, **kwargs):
  12. pynvml.nvmlInit()
  13. try:
  14. return fn(*args, **kwargs)
  15. finally:
  16. pynvml.nvmlShutdown()
  17. return wrapper
  18. @lru_cache(maxsize=8)
  19. @with_nvml_context
  20. def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
  21. handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
  22. return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
  23. def device_id_to_physical_device_id(device_id: int) -> int:
  24. if "CUDA_VISIBLE_DEVICES" in os.environ:
  25. device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
  26. device_ids = [int(device_id) for device_id in device_ids]
  27. physical_device_id = device_ids[device_id]
  28. else:
  29. physical_device_id = device_id
  30. return physical_device_id
  31. class CudaPlatform(Platform):
  32. _enum = PlatformEnum.CUDA
  33. @staticmethod
  34. def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
  35. physical_device_id = device_id_to_physical_device_id(device_id)
  36. return get_physical_device_capability(physical_device_id)