chore: check if process is on the same node

AlpinDale 7 months ago
parent
commit
270bd333af

aphrodite/distributed/device_communicators/custom_all_reduce.py  (+8 -1)

@@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup
 from aphrodite.distributed.device_communicators.custom_all_reduce_utils import \
     gpu_p2p_access_check
 from aphrodite.distributed.parallel_state import (
-    get_local_rank, get_tensor_model_parallel_cpu_group)
+    get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
 
 try:
     import pynvml
@@ -105,6 +105,13 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
+        if not is_in_the_same_node(group):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes.")
+            return
+
         rank = dist.get_rank(group=self.group)
         world_size = dist.get_world_size(group=self.group)
         if world_size == 1:
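
For context, the guard added above returns from __init__ before any rank or P2P probing runs when the group spans nodes. A minimal sketch of that early-exit pattern; the class name and the "disabled" flag are illustrative and not taken from this diff:

    from aphrodite.distributed.parallel_state import is_in_the_same_node

    class _FastAllreduceSketch:
        """Illustrative stand-in for the guarded __init__ shown in the hunk above."""

        def __init__(self, group):
            # assume callers treat an early return as "communicator disabled"
            self.disabled = True
            if not is_in_the_same_node(group):
                # multi-node group: shared-memory / P2P fast paths cannot apply
                return
            # otherwise proceed with the rank / world-size checks and P2P probing
            self.disabled = False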

aphrodite/distributed/parallel_state.py  (+67 -0)

@@ -4,6 +4,8 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
+from multiprocessing import resource_tracker, shared_memory
 import os
 from typing import List, Optional
 
@@ -369,3 +371,68 @@ def destroy_model_parallel():
     _PP_DEVICE_GROUP = None
     global _PP_GLOBAL_RANKS
     _PP_GLOBAL_RANKS = None
+
+
+def is_in_the_same_node(pg: ProcessGroup):
+    """
+    This is a collective operation that checks if all processes in the group
+    are in the same node. It tests if all processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(
+        pg) != torch.distributed.Backend.NCCL, (
+            "is_in_the_same_node should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
+
+    # local tensor in each process to store the result
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
+
+    magic_message = b"magic_message"
+    shm = None
+
+    try:
+        with contextlib.suppress(OSError):
+            if rank == 0:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                        src=ranks[0],
+                                                        group=pg)
+                is_in_the_same_node[0] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[0],
+                                                        group=pg)
+                name = recv[0]
+                shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    except Exception as e:
+        logger.error("Error ignored in is_in_the_same_node: %s", e)
+    finally:
+        if shm:
+            shm.close()
+
+    torch.distributed.barrier(group=pg)
+
+    # clean up the shared memory segment
+    with contextlib.suppress(OSError):
+        if rank == 0:
+            if shm:
+                shm.unlink()
+        else:
+            if shm:
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                resource_tracker.unregister(
+                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+
+    return is_in_the_same_node.sum().item() == world_size
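
A minimal usage sketch for the new helper (not part of the commit): it must be called collectively by every rank in the group, and the group must use a CPU-capable backend such as gloo because of the non-NCCL assertion. The launch details below (torchrun setting the rendezvous environment variables) are assumptions.

    import torch.distributed as dist
    from aphrodite.distributed.parallel_state import is_in_the_same_node

    # assumes the script is started with torchrun, which sets RANK/WORLD_SIZE etc.
    dist.init_process_group(backend="gloo")
    cpu_group = dist.group.WORLD  # any non-NCCL ProcessGroup works

    if is_in_the_same_node(cpu_group):
        print("all ranks share a node; single-node fast paths are safe")
    else:
        print("group spans nodes; skip shared-memory optimizations")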