cuda.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. """Code inside this file can safely assume cuda platform, e.g. importing
  2. pynvml. However, it should not initialize cuda context.
  3. """
  4. import os
  5. from functools import lru_cache, wraps
  6. from typing import List, Tuple
  7. import pynvml
  8. from loguru import logger
  9. from .interface import Platform, PlatformEnum
  10. # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it does not initialize CUDA.
  14. def with_nvml_context(fn):
  15. @wraps(fn)
  16. def wrapper(*args, **kwargs):
  17. pynvml.nvmlInit()
  18. try:
  19. return fn(*args, **kwargs)
  20. finally:
  21. pynvml.nvmlShutdown()
  22. return wrapper
  23. @lru_cache(maxsize=8)
  24. @with_nvml_context
  25. def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
  26. handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
  27. return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
  28. def device_id_to_physical_device_id(device_id: int) -> int:
  29. if "CUDA_VISIBLE_DEVICES" in os.environ:
  30. device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
  31. device_ids = [int(device_id) for device_id in device_ids]
  32. physical_device_id = device_ids[device_id]
  33. else:
  34. physical_device_id = device_id
  35. return physical_device_id
  36. class CudaPlatform(Platform):
  37. _enum = PlatformEnum.CUDA
  38. @staticmethod
  39. def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
  40. physical_device_id = device_id_to_physical_device_id(device_id)
  41. return get_physical_device_capability(physical_device_id)
  42. @staticmethod
  43. @with_nvml_context
  44. def is_full_nvlink(physical_device_ids: List[int]) -> bool:
  45. """
  46. query if the set of gpus are fully connected by nvlink (1 hop)
  47. """
  48. handles = [
  49. pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
  50. ]
  51. for i, handle in enumerate(handles):
  52. for j, peer_handle in enumerate(handles):
  53. if i < j:
  54. try:
  55. p2p_status = pynvml.nvmlDeviceGetP2PStatus(
  56. handle, peer_handle,
  57. pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
  58. if p2p_status != pynvml.NVML_P2P_STATUS_OK:
  59. return False
  60. except pynvml.NVMLError as error:
  61. logger.error(
  62. "NVLink detection failed. This is normal if your"
  63. " machine has no NVLink equipped.",
  64. exc_info=error)
  65. return False
  66. return True