@@ -4,6 +4,8 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
+from multiprocessing import resource_tracker, shared_memory
 import os
 from typing import List, Optional

@@ -369,3 +371,68 @@ def destroy_model_parallel():
     _PP_DEVICE_GROUP = None
     global _PP_GLOBAL_RANKS
     _PP_GLOBAL_RANKS = None
+
+
+def is_in_the_same_node(pg: ProcessGroup):
+    """
+    This is a collective operation that checks if all processes in the group
+    are in the same node. It tests if all processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(
+        pg) != torch.distributed.Backend.NCCL, (
+            "is_in_the_same_node should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
+
+    magic_message = b"magic_message"
+    shm = None
+
+    try:
+        with contextlib.suppress(OSError):
+            if rank == 0:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                         src=ranks[0],
+                                                         group=pg)
+                is_in_the_same_node[0] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[0],
+                                                        group=pg)
+                name = recv[0]
+                shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    except Exception as e:
+        logger.error("Error ignored in is_in_the_same_node: %s", e)
+    finally:
+        if shm:
+            shm.close()
+
+    torch.distributed.barrier(group=pg)
+
+    # clean up the shared memory segment
+    with contextlib.suppress(OSError):
+        if rank == 0:
+            if shm:
+                shm.unlink()
+        else:
+            if shm:
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                resource_tracker.unregister(
+                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+
+    return is_in_the_same_node.sum().item() == world_size
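Because the assertion at the top of the new function rejects NCCL-backed groups, callers are expected to pass a CPU-capable process group, for example one created with the gloo backend. The following is a minimal usage sketch and not part of the diff: the import path, group setup, script name, and launch command are illustrative assumptions only.

# Illustrative usage sketch (not part of the diff above).
# Launch with e.g.: torchrun --nproc_per_node=2 check_same_node.py
import torch.distributed as dist

# Hypothetical import path; adjust to wherever this parallel_state module lives.
from parallel_state import is_in_the_same_node

if __name__ == "__main__":
    # Use a CPU-friendly backend; the assert above rejects NCCL groups.
    dist.init_process_group(backend="gloo")
    cpu_group = dist.new_group(backend="gloo")  # group spanning all ranks
    same_node = is_in_the_same_node(cpu_group)
    if dist.get_rank() == 0:
        print(f"All ranks share one node: {same_node}")
    dist.destroy_process_group()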