# cuda.py
  1. """Code inside this file can safely assume cuda platform, e.g. importing
  2. pynvml. However, it should not initialize cuda context.
  3. """
  4. import os
  5. from functools import lru_cache, wraps
  6. from typing import List, Tuple
  7. import pynvml
  8. from loguru import logger
  9. from .interface import Platform, PlatformEnum
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`:
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.


def with_nvml_context(fn):

    @wraps(fn)
    def wrapper(*args, **kwargs):
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper

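# Example (illustrative only): any helper that talks to NVML can reuse the
# decorator so nvmlInit/nvmlShutdown are always paired, e.g. a hypothetical
# device-count helper:
#
#     @with_nvml_context
#     def get_physical_device_count() -> int:
#         return pynvml.nvmlDeviceGetCount()
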
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    return pynvml.nvmlDeviceGetCudaComputeCapability(handle)


@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_name(device_id: int = 0) -> str:
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    return pynvml.nvmlDeviceGetName(handle)

@with_nvml_context
def warn_if_different_devices():
    device_ids: int = pynvml.nvmlDeviceGetCount()
    if device_ids > 1:
        device_names = [get_physical_device_name(i) for i in range(device_ids)]
        if len(set(device_names)) > 1 and os.environ.get(
                "CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
            logger.warning(
                f"Detected different devices in the system: \n{device_names}\n"
                "Please make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                "avoid unexpected behavior.")

def device_id_to_physical_device_id(device_id: int) -> int:
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        device_ids = [int(device_id) for device_id in device_ids]
        physical_device_id = device_ids[device_id]
    else:
        physical_device_id = device_id
    return physical_device_id

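# Example for device_id_to_physical_device_id: with CUDA_VISIBLE_DEVICES="2,3",
# logical device 0 maps to physical device 2 and logical device 1 to 3.
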
class CudaPlatform(Platform):
    _enum = PlatformEnum.CUDA

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        physical_device_id = device_id_to_physical_device_id(device_id)
        return get_physical_device_capability(physical_device_id)

    @staticmethod
    def get_device_name(device_id: int = 0) -> str:
        physical_device_id = device_id_to_physical_device_id(device_id)
        return get_physical_device_name(physical_device_id)

    @staticmethod
    @with_nvml_context
    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
        """
        Query whether the given set of GPUs is fully connected by NVLink (1 hop).
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle, peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError as error:
                        # loguru has no `exc_info` kwarg; attach the exception
                        # via `opt()` so the traceback is logged.
                        logger.opt(exception=error).error(
                            "NVLink detection failed. This is normal if your"
                            " machine has no NVLink equipped.")
                        return False
        return True
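

# Minimal usage sketch (illustrative only; assumes a CUDA host with pynvml
# installed and is not part of the platform interface). Guarded so importing
# this module never triggers it.
if __name__ == "__main__":
    warn_if_different_devices()
    print("device 0 name:", CudaPlatform.get_device_name(0))
    print("device 0 capability:", CudaPlatform.get_device_capability(0))
    # is_full_nvlink expects *physical* device ids; a single-device list is
    # trivially "fully connected". Pass all physical ids for a real check.
    print("NVLink (device 0 only):",
          CudaPlatform.is_full_nvlink([device_id_to_physical_device_id(0)]))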