|
@@ -141,8 +141,10 @@ def _is_full_nvlink(rank, world_size):
|
|
for i in range(world_size):
|
|
for i in range(world_size):
|
|
if i != rank:
|
|
if i != rank:
|
|
try:
|
|
try:
|
|
- link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
|
|
|
|
- if not link_state:
|
|
|
|
|
|
+ peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
|
|
|
+ p2p_status = pynvml.nvmlDeviceGetP2PStatus(
|
|
|
|
+ handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
|
|
|
|
+ if p2p_status != pynvml.NVML_P2P_STATUS_OK:
|
|
return False
|
|
return False
|
|
except pynvml.NVMLError as error:
|
|
except pynvml.NVMLError as error:
|
|
logger.info(
|
|
logger.info(
|