
refactor custom allreduce to support multiple tp groups

AlpinDale 8 months ago
parent
commit
b984fe4a91

+ 33 - 10
aphrodite/distributed/communication_op.py

@@ -1,5 +1,5 @@
 from collections import namedtuple
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -9,11 +9,12 @@ from .parallel_state import (get_cpu_world_group,
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
+                             get_tp_ca_communicator,
                              get_tp_pynccl_communicator)
 
 
 @contextmanager
-def graph_capture_mode():
+def graph_mode():
     # In graph capture, we have to be very careful about the collective
     # operations. The current status is:
     #     allreduce \ Mode   |  Eager  |  Graph  |
@@ -24,10 +25,32 @@ def graph_capture_mode():
     #
     # Note that custom allreduce will have a runtime check, if the tensor size
     # is too large, it will fallback to the next available option.
+    # In summary: when using CUDA graph, we use either the custom
+    # all-reduce kernel or pynccl. When not using CUDA graph, we use
+    # either the custom all-reduce kernel or PyTorch NCCL. We always
+    # prioritize the custom all-reduce kernel, but fall back to pynccl
+    # or PyTorch NCCL if it is disabled or not supported.
     pynccl_comm = get_tp_pynccl_communicator()
-    assert pynccl_comm is not None
-    with pynccl_comm.change_state(enable=True,
-                                  stream=torch.cuda.current_stream()):
+    if pynccl_comm is None:
+        context = nullcontext()
+    else:
+        context = pynccl_comm.change_state(enable=True,
+                                           stream=torch.cuda.current_stream())
+    with context:
+        yield
+
+
+@contextmanager
+def graph_capture():
+    """
+    `graph_capture` is a context manager which should include the code that
+    is capturing the CUDA graph. Its main purpose is to ensure that
+    some operations will be run after the graph is captured, before the graph
+    is replayed.
+    """
+    ca_comm = get_tp_ca_communicator()
+    context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with context:
         yield
 
 
@@ -42,15 +65,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the return
     value as the output.
     """
-    from aphrodite.distributed.device_communicators.custom_all_reduce import \
-        custom_all_reduce
+    ca_comm = get_tp_ca_communicator()
 
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
         return input_
-    out = custom_all_reduce(input_)
-    if out is not None:
-        return out
+    if ca_comm is not None:
+        out = ca_comm.custom_all_reduce(input_)
+        if out is not None:
+            return out
     pynccl_comm = get_tp_pynccl_communicator()
     if (pynccl_comm is not None and not pynccl_comm.disabled):
         pynccl_comm.all_reduce(input_)

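To make the new dispatch concrete: below is a minimal, stand-alone sketch of the fallback chain that `tensor_model_parallel_all_reduce` now implements. It is illustrative only and not part of the commit; the `ca_comm` and `pynccl_comm` arguments stand in for whatever `get_tp_ca_communicator()` and `get_tp_pynccl_communicator()` return.

import torch
import torch.distributed as dist

def all_reduce_fallback_sketch(input_: torch.Tensor,
                               ca_comm=None,
                               pynccl_comm=None) -> torch.Tensor:
    # Priority: custom all-reduce kernel, then pynccl (only when enabled,
    # e.g. inside graph_mode()), then plain torch.distributed NCCL.
    if ca_comm is not None:
        out = ca_comm.custom_all_reduce(input_)
        if out is not None:      # None means the tensor was not handled
            return out
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.all_reduce(input_)   # in-place
        return input_
    dist.all_reduce(input_)              # in-place
    return input_
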
+ 201 - 154
aphrodite/distributed/device_communicators/custom_all_reduce.py

@@ -1,164 +1,68 @@
+import os
 from contextlib import contextmanager
-from typing import Optional
-from loguru import logger
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as dist
+from loguru import logger
+from torch.distributed import ProcessGroup
 
-from aphrodite.distributed import gpu_p2p_access_check
+from aphrodite.distributed.parallel_state import (
+    get_local_rank, get_tensor_model_parallel_cpu_group)
 
 try:
-    from aphrodite._C import custom_ar
     import pynvml
+
+    from aphrodite._C import custom_ar
+
+    @contextmanager
+    def _nvml():
+        try:
+            pynvml.nvmlInit()
+            yield
+        finally:
+            pynvml.nvmlShutdown()
+
 except ImportError:
     # For AMD GPUs
     custom_ar = None
     pynvml = None
 
-_CA_HANDLE = None
-_IS_CAPTURING = False
-_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
-
-
-def init_custom_ar() -> None:
-    from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                       get_tensor_model_parallel_world_size)
-    global _CA_HANDLE
-    if _CA_HANDLE is not None:
-        return
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    if world_size == 1:
-        return
-    if world_size not in _SUPPORTED_WORLD_SIZES:
-        logger.warning(
-            "Custom allreduce is disabled due to an unsupported world size: "
-            "%d. Supported world sizes: %s. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.", world_size,
-            str(_SUPPORTED_WORLD_SIZES))
-        return
-    num_dev = torch.cuda.device_count()
-    # note: num dev can be larger than world_size if we're only using
-    # first few GPUs
-    if num_dev < world_size:
-        logger.warning(
-            "Cannot test GPU P2P because not all GPUs are visible to the "
-            "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
-            " is set.")
-        return
-
-    # we only use a subset of GPUs here
-    # so we only need to check the nvlink connectivity of these GPUs
-    num_dev = world_size
-    # test nvlink first, this will filter out most of the cases
-    # where custom allreduce is not supported
-    full_nvlink = _is_full_nvlink(rank, world_size)
-    if world_size > 2 and not full_nvlink:
-        logger.warning(
-            "Custom allreduce is disabled because it's not supported on more"
-            " than two PCIe-only GPUs. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.")
-        return
-    # test P2P capability
-    # this is expensive to compute at the first time
-    # then we cache the result
-    if not _can_p2p(rank, world_size):
-        logger.warning(
-            "Custom allreduce is disabled because your platform lacks GPU P2P"
-            " capability or P2P test failed. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.")
-        return
-    _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
-
-
-def begin_capture() -> None:
-    global _IS_CAPTURING
-    _IS_CAPTURING = True
-
-
-def end_capture() -> None:
-    global _IS_CAPTURING
-    _IS_CAPTURING = False
-
-
-def is_capturing() -> bool:
-    return _IS_CAPTURING and _CA_HANDLE is not None
-
-
-def get_handle() -> Optional["CustomAllreduce"]:
-    return _CA_HANDLE
-
-
-def is_initialized() -> bool:
-    return _CA_HANDLE is not None
-
-
-@contextmanager
-def capture():
-    try:
-        begin_capture()
-        yield
-    finally:
-        end_capture()
-        handle = get_handle()
-        if handle is not None:
-            handle.register_graph_buffers()
-
-
-# pylint: disable=redefined-builtin
-def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
-    ca_handle = get_handle()
-    # when custom allreduce is disabled, this will be None
-    if ca_handle is None:
-        return
-    if is_capturing():
-        if torch.cuda.is_current_stream_capturing():
-            if ca_handle.should_custom_ar(input):
-                return ca_handle.all_reduce_reg(input)
-        else:
-            if ca_handle.should_custom_ar(input):
-                # if warm up, mimic the allocation pattern
-                # since custom allreduce is out-of-place
-                return torch.empty_like(input)
-    else:
-        # NOTE: outside of cuda graph context,
-        # custom allreduce incurs a cost of cudaMemcpy, which should
-        # be small(<=1% of overall latency) compared to the performance
-        # gains of using custom kernels
-        if ca_handle.should_custom_ar(input):
-            return ca_handle.all_reduce_unreg(input)
-
-
-@contextmanager
-def _nvml():
-    try:
-        pynvml.nvmlInit()
-        yield
-    finally:
-        pynvml.nvmlShutdown()
-
-
-# query if the set of gpus are fully connected by nvlink (1 hop)
+    @contextmanager
+    def _nvml():
+        try:
+            yield
+        finally:
+            pass
+
+
 @_nvml()
-def _is_full_nvlink(rank, world_size):
-    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
-    for i in range(world_size):
-        if i != rank:
-            try:
-                peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+def _is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    query if the set of gpus are fully connected by nvlink (1 hop)
+    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
+    so it works on real physical device ids.
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine is not equipped with NVLink.",
+                        exc_info=error)
                     return False
-            except pynvml.NVMLError as error:
-                logger.info(
-                    f"NVLink detection failed with message \"{str(error)}\". "
-                    "This is normal if your machine has no NVLink equipped")
-                return False
     return True
 
 
 def _can_p2p(rank: int, world_size: int) -> bool:
+    from aphrodite.distributed.utils import gpu_p2p_access_check
     for i in range(world_size):
         if i == rank:
             continue
@@ -169,22 +73,112 @@ def _can_p2p(rank: int, world_size: int) -> bool:
 
 class CustomAllreduce:
 
+    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+
     # max_size: max supported allreduce size
     def __init__(self,
-                 rank,
-                 world_size,
-                 full_nvlink,
+                 group: Optional[ProcessGroup] = None,
+                 device: Optional[Union[int, str, torch.device]] = None,
                  max_size=8192 * 1024) -> None:
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the CustomAllreduce to. If None,
+                it will be bound to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bound to a unique device, and all communicators in this group
+        are on the same node.
+        """
+        self._IS_CAPTURING = False
+        self.disabled = True
+
+        if custom_ar is None:
+            # disable because of missing custom allreduce library
+            # e.g. in a non-cuda environment
+            return
+
+        group = group or get_tensor_model_parallel_cpu_group()
+        self.group = group
+
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "CustomAllreduce should be attached to a non-NCCL group.")
+
+        rank = dist.get_rank(group=self.group)
+        world_size = dist.get_world_size(group=self.group)
+        if world_size == 1:
+            # No need to initialize custom allreduce for single GPU case.
+            return
+
+        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "Custom allreduce is disabled due to an unsupported world"
+                " size: %d. Supported world sizes: %s. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.",
+                world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
+            return
+
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(torch.cuda.device_count()))
+
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id],
+                              dtype=torch.int,
+                              device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        full_nvlink = _is_full_nvlink(physical_device_ids)
+        if world_size > 2 and not full_nvlink:
+            logger.warning(
+                "Custom allreduce is disabled because it's not supported on"
+                " more than two PCIe-only GPUs. To silence this warning, "
+                "specify disable_custom_all_reduce=True explicitly.")
+            return
+        # test P2P capability, this checks software/cudaruntime support
+        # this is expensive to compute at the first time
+        # then we cache the result
+        if not _can_p2p(rank, world_size):
+            logger.warning(
+                "Custom allreduce is disabled because your platform lacks "
+                "GPU P2P capability or P2P test failed. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.")
+            return
+
+        self.disabled = False
         # buffers memory are owned by this Python class and passed to C++
         # meta data composes of two parts: meta data for synchronization
         # (256 bytes) and a temporary buffer for storing intermediate
         # allreduce results.
         self.meta = torch.zeros(custom_ar.meta_size() + max_size,
                                 dtype=torch.uint8,
-                                device="cuda")
+                                device=self.device)
         # This is a pre-registered IPC buffer. In eager mode, input tensors
         # are first copied into this buffer before allreduce is performed
-        self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
+        self.buffer = torch.empty(max_size,
+                                  dtype=torch.uint8,
+                                  device=self.device)
         # This is a buffer for storing the tuples of pointers pointing to
         # IPC buffers from all ranks. Each registered tuple has size of
         # 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -192,8 +186,9 @@ class CustomAllreduce:
         # needs less than 10000 of registered tuples.
         self.rank_data = torch.empty(8 * 1024 * 1024,
                                      dtype=torch.uint8,
-                                     device="cuda")
+                                     device=self.device)
         self.max_size = max_size
+        self.rank = rank
         self.world_size = world_size
         handles, offsets = self._get_ipc_meta(self.meta)
         self.full_nvlink = full_nvlink
@@ -202,8 +197,22 @@ class CustomAllreduce:
                                              self.full_nvlink)
         self.register_buffer(self.buffer)
 
+    @contextmanager
+    def capture(self):
+        """
+        The main responsibility of this context manager is the 
+        `register_graph_buffers` call at the end of the context.
+        It records all the buffer addresses used in the CUDA graph.
+        """
+        try:
+            self._IS_CAPTURING = True
+            yield
+        finally:
+            self._IS_CAPTURING = False
+            if not self.disabled:
+                self.register_graph_buffers()
+
     def _get_ipc_meta(self, inp: torch.Tensor):
-        # pylint: disable=protected-access
         data = inp.untyped_storage()._share_cuda_()
         shard_data = (
             data[1],  # ipc handle to base ptr
@@ -212,14 +221,29 @@ class CustomAllreduce:
         return self._gather_ipc_meta(shard_data)
 
     def _gather_ipc_meta(self, shard_data):
-        all_data = [None] * self.world_size
-        dist.all_gather_object(all_data, shard_data)
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it will create a list of the same reference
+        all_data: List[Optional[Any]] = [[None]
+                                         for i in range(self.world_size)]
+        all_data[self.rank][0] = shard_data
+
+        ranks = dist.get_process_group_ranks(group=self.group)
+        ranks.sort()
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(all_data[i],
+                                       src=rank,
+                                       group=self.group,
+                                       device="cpu")
+
+        # we cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
 
         handles = []
         offsets = []
         for i in range(len(all_data)):
-            handles.append(all_data[i][0])
-            offsets.append(all_data[i][1])
+            handles.append(all_data[i][0][0])  # type: ignore
+            offsets.append(all_data[i][0][1])  # type: ignore
         return handles, offsets
 
     def register_buffer(self, inp: torch.Tensor):
@@ -251,8 +275,31 @@ class CustomAllreduce:
         custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
         return out
 
+    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
+        # when custom allreduce is disabled, this will be None
+        if self.disabled:
+            return None
+        if self._IS_CAPTURING:
+            if torch.cuda.is_current_stream_capturing():
+                if self.should_custom_ar(input):
+                    return self.all_reduce_reg(input)
+            else:
+                if self.should_custom_ar(input):
+                    # if warm up, mimic the allocation pattern
+                    # since custom allreduce is out-of-place
+                    return torch.empty_like(input)
+        else:
+            # note: outside of cuda graph context,
+            # custom allreduce incurs a cost of cudaMemcpy, which should
+            # be small (<=1% of overall latency) compared to the performance
+            # gains of using custom kernels
+            if self.should_custom_ar(input):
+                return self.all_reduce_unreg(input)
+
+        return None
+
     def close(self):
-        if self._ptr:
+        if not self.disabled and self._ptr:
             custom_ar.dispose(self._ptr)
             self._ptr = 0
 

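For orientation, a rough usage sketch of the refactored class, which is now constructed per tensor-parallel group instead of through the removed module-level `_CA_HANDLE` / `init_custom_ar()` globals. The `tp_cpu_group`, `local_rank`, and `tensor` names are placeholders; the real wiring happens in `initialize_model_parallel` further down.

from aphrodite.distributed.device_communicators.custom_all_reduce import (
    CustomAllreduce)

def tp_group_all_reduce_sketch(tp_cpu_group, local_rank, tensor):
    # One CustomAllreduce per TP group, bound to this rank's device. The
    # group must be a non-NCCL (e.g. gloo) group, as asserted in __init__.
    ca_comm = CustomAllreduce(group=tp_cpu_group, device=local_rank)
    if ca_comm.disabled:
        return None
    # Outside CUDA graphs this goes through the pre-registered IPC buffer
    # (all_reduce_unreg); returns None if the tensor is unsupported.
    return ca_comm.custom_all_reduce(tensor)

During CUDA graph capture, the same object would instead be driven inside its `capture()` context manager (as `graph_capture()` above does), so that `register_graph_buffers()` runs once capture finishes.
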
+ 3 - 1
aphrodite/distributed/device_communicators/pynccl.py

@@ -95,8 +95,10 @@ class PyNcclCommunicator:
             self.stream = torch.cuda.Stream()
 
             # A small all_reduce for warmup.
-            self.all_reduce(torch.zeros(1, device=device))
+            data = torch.zeros(1, device=device)
+            self.all_reduce(data)
             self.stream.synchronize()
+            del data
 
         # by default it is disabled, e.g. in profiling models and prefill phase.
         # to use it, use under `with obj.change_state(enable=True)`, usually

+ 257 - 0
aphrodite/distributed/device_communicators/pynccl_wrapper.py

@@ -0,0 +1,257 @@
+# This file is a pure Python wrapper for the NCCL library.
+# The main purpose is to use NCCL combined with CUDA graph.
+# Before writing this script, we tried the following approach:
+# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
+#  often gets stuck when initializing the NCCL communicator.
+# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
+#  contains many other potential cuda APIs that are not allowed during
+#  capturing the CUDA graph. For further details, please check
+# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
+#
+# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
+# doable, but we often encounter issues related to nccl versions, and need
+# to switch between different versions of NCCL. See
+# https://github.com/NVIDIA/nccl/issues/1234 for more details.
+# A C/C++ binding is not flexible enough to handle this. It requires
+# recompilation of the code every time we want to switch between different
+# versions. This current implementation, with a **pure** Python wrapper, is
+# more flexible. We can easily switch between different versions of NCCL by
+# changing the environment variable `APHRODITE_NCCL_SO_PATH`, or the `so_file`
+# variable in the code.
+
+import ctypes
+import platform
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import torch
+from loguru import logger
+from torch.distributed import ReduceOp
+
+from aphrodite.common.utils import find_nccl_library, nccl_integrity_check
+
+# === export types and functions from nccl to Python ===
+# for the original nccl definition, please check
+# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
+
+ncclResult_t = ctypes.c_int
+ncclComm_t = ctypes.c_void_p
+
+
+class ncclUniqueId(ctypes.Structure):
+    _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+cudaStream_t = ctypes.c_void_p
+buffer_type = ctypes.c_void_p
+
+ncclDataType_t = ctypes.c_int
+
+
+class ncclDataTypeEnum:
+    ncclInt8 = 0
+    ncclChar = 0
+    ncclUint8 = 1
+    ncclInt32 = 2
+    ncclInt = 2
+    ncclUint32 = 3
+    ncclInt64 = 4
+    ncclUint64 = 5
+    ncclFloat16 = 6
+    ncclHalf = 6
+    ncclFloat32 = 7
+    ncclFloat = 7
+    ncclFloat64 = 8
+    ncclDouble = 8
+    ncclBfloat16 = 9
+    ncclNumTypes = 10
+
+    @classmethod
+    def from_torch(cls, dtype: torch.dtype) -> int:
+        if dtype == torch.int8:
+            return cls.ncclInt8
+        if dtype == torch.uint8:
+            return cls.ncclUint8
+        if dtype == torch.int32:
+            return cls.ncclInt32
+        if dtype == torch.int64:
+            return cls.ncclInt64
+        if dtype == torch.float16:
+            return cls.ncclFloat16
+        if dtype == torch.float32:
+            return cls.ncclFloat32
+        if dtype == torch.float64:
+            return cls.ncclFloat64
+        if dtype == torch.bfloat16:
+            return cls.ncclBfloat16
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
+    ncclSum = 0
+    ncclProd = 1
+    ncclMax = 2
+    ncclMin = 3
+    ncclAvg = 4
+    ncclNumOps = 5
+
+    @classmethod
+    def from_torch(cls, op: ReduceOp) -> int:
+        if op == ReduceOp.SUM:
+            return cls.ncclSum
+        if op == ReduceOp.PRODUCT:
+            return cls.ncclProd
+        if op == ReduceOp.MAX:
+            return cls.ncclMax
+        if op == ReduceOp.MIN:
+            return cls.ncclMin
+        if op == ReduceOp.AVG:
+            return cls.ncclAvg
+        raise ValueError(f"Unsupported op: {op}")
+
+
+@dataclass
+class Function:
+    name: str
+    restype: Any
+    argtypes: List[Any]
+
+
+class NCCLLibrary:
+    exported_functions = [
+        # const char* ncclGetErrorString(ncclResult_t result)
+        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
+        # ncclResult_t  ncclGetVersion(int *version);
+        Function("ncclGetVersion", ncclResult_t,
+                 [ctypes.POINTER(ctypes.c_int)]),
+        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
+        Function("ncclGetUniqueId", ncclResult_t,
+                 [ctypes.POINTER(ncclUniqueId)]),
+        # ncclResult_t  ncclCommInitRank(
+        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
+        # note that ncclComm_t is a pointer type, so the first argument
+        # is a pointer to a pointer
+        Function("ncclCommInitRank", ncclResult_t, [
+            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
+            ctypes.c_int
+        ]),
+        # ncclResult_t  ncclAllReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+        #   cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclAllReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ncclComm_t, cudaStream_t
+        ]),
+
+        # be cautious! this is a collective call, it will block until all
+        # processes in the communicator have called this function.
+        # because Python object destruction can happen in random order,
+        # it is better not to call it at all.
+        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);
+        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+    ]
+
+    # class attribute to store the mapping from the path to the library
+    # to avoid loading the same library multiple times
+    path_to_library_cache: Dict[str, Any] = {}
+
+    # class attribute to store the mapping from library path
+    #  to the corresponding dictionary
+    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, so_file: Optional[str] = None):
+
+        so_file = so_file or find_nccl_library()
+
+        try:
+            # load the library in another process.
+            # if it core dumps, it will not crash the current process
+            nccl_integrity_check(so_file)
+        except Exception as e:
+            logger.error(
+                f"Failed to load NCCL library from {so_file}. "
+                "It is expected if you are not running on NVIDIA/AMD GPUs. "
+                "Otherwise, the nccl library might not exist, be corrupted "
+                "or it does not support the current platform "
+                f"{platform.platform()}. "
+                "One solution is to download libnccl2 version 2.18 from "
+                "https://developer.download.nvidia.com/compute/cuda/repos/ "
+                "and extract the libnccl.so.2 file. If you already have the "
+                "library, please set the environment variable "
+                "APHRODITE_NCCL_SO_PATH to point to the correct nccl "
+                "library path.")
+            raise e
+
+        if so_file not in NCCLLibrary.path_to_dict_mapping:
+            lib = ctypes.CDLL(so_file)
+            NCCLLibrary.path_to_library_cache[so_file] = lib
+        self.lib = NCCLLibrary.path_to_library_cache[so_file]
+
+        if so_file not in NCCLLibrary.path_to_dict_mapping:
+            _funcs = {}
+            for func in NCCLLibrary.exported_functions:
+                f = getattr(self.lib, func.name)
+                f.restype = func.restype
+                f.argtypes = func.argtypes
+                _funcs[func.name] = f
+            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
+        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
+
+    def ncclGetErrorString(self, result: ncclResult_t) -> str:
+        return self._funcs["ncclGetErrorString"](result).decode("utf-8")
+
+    def NCCL_CHECK(self, result: ncclResult_t) -> None:
+        if result != 0:
+            error_str = self.ncclGetErrorString(result)
+            raise RuntimeError(f"NCCL error: {error_str}")
+
+    def ncclGetVersion(self) -> str:
+        version = ctypes.c_int()
+        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
+        version_str = str(version.value)
+        # something like 21903 --> "2.19.3"
+        major = version_str[0].lstrip("0")
+        minor = version_str[1:3].lstrip("0")
+        patch = version_str[3:].lstrip("0")
+        return f"{major}.{minor}.{patch}"
+
+    def ncclGetUniqueId(self) -> ncclUniqueId:
+        unique_id = ncclUniqueId()
+        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
+            ctypes.byref(unique_id)))
+        return unique_id
+
+    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
+                         rank: int) -> ncclComm_t:
+        comm = ncclComm_t()
+        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
+                                                        world_size, unique_id,
+                                                        rank))
+        return comm
+
+    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                      count: int, datatype: int, op: int, comm: ncclComm_t,
+                      stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
+                                                     datatype, op, comm,
+                                                     stream))
+
+    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
+
+
+__all__ = [
+    "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
+    "ncclComm_t", "cudaStream_t", "buffer_type"
+]

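A minimal single-process sketch of exercising the new wrapper (assuming an NCCL shared library is discoverable via `find_nccl_library()` or `APHRODITE_NCCL_SO_PATH`). The commented-out collective call is only meaningful once the unique id has been shared with every rank over an existing CPU process group, which is what the pynccl communicator does.

from aphrodite.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary)

nccl = NCCLLibrary()                # loads the .so after the integrity check
print(nccl.ncclGetVersion())        # e.g. "2.19.3"
unique_id = nccl.ncclGetUniqueId()  # rank 0 generates this and broadcasts it
# comm = nccl.ncclCommInitRank(world_size, unique_id, rank)  # collective:
# every rank must call it with the same unique_id before any ncclAllReduce.
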
+ 26 - 1
aphrodite/distributed/parallel_state.py

@@ -11,10 +11,12 @@ import torch
 from loguru import logger
 from torch.distributed import ProcessGroup
 
+_ENABLE_CUSTOM_ALL_REDUCE = True
 # Tensor model parallel group that the current rank belongs to.
 _TP_DEVICE_GROUP: Optional[ProcessGroup] = None
 _TP_CPU_GROUP: Optional[ProcessGroup] = None
 _TP_PYNCCL_COMMUNICATOR = None
+_TP_CA_COMMUNICATOR = None
 # Pipeline model parallel group that the current rank belongs to.
 _PP_DEVICE_GROUP: Optional[ProcessGroup] = None
 
@@ -45,11 +47,21 @@ _PP_GLOBAL_RANKS: Optional[List[int]] = None
 _LOCAL_RANK = -1
 
 
+def set_custom_all_reduce(enable: bool):
+    global _ENABLE_CUSTOM_ALL_REDUCE
+    _ENABLE_CUSTOM_ALL_REDUCE = enable
+
+
 def get_tp_pynccl_communicator():
     global _TP_PYNCCL_COMMUNICATOR
     return _TP_PYNCCL_COMMUNICATOR
 
 
+def get_tp_ca_communicator():
+    global _TP_CA_COMMUNICATOR
+    return _TP_CA_COMMUNICATOR
+
+
 def get_local_rank():
     global _LOCAL_RANK
     return _LOCAL_RANK
@@ -93,6 +105,9 @@ def init_distributed_environment(
         if torch.cuda.is_available():
             data = data.to(device=f"cuda:{local_rank}")
         torch.distributed.all_reduce(data)
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        del data
 
 
 def initialize_model_parallel(
@@ -142,7 +157,8 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR
     assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
@@ -162,6 +178,15 @@ def initialize_model_parallel(
         device=_LOCAL_RANK,
     )
 
+    # Initialize a custom fast all-reduce implementation.
+    if _ENABLE_CUSTOM_ALL_REDUCE:
+        from aphrodite.distributed.device_communicators.custom_all_reduce \
+            import CustomAllreduce
+        _TP_CA_COMMUNICATOR = CustomAllreduce(
+            group=_TP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
+
     # Build the pipeline model-parallel groups.
     global _PP_DEVICE_GROUP
     global _PP_GLOBAL_RANKS

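Putting the initialization pieces together, a sketch of the intended call order (mirroring `init_worker_distributed_environment` in `worker.py` below; argument names are illustrative and passing the model-parallel sizes positionally is an assumption):

def init_distributed_sketch(world_size, rank, distributed_init_method,
                            local_rank, tp_size, pp_size,
                            disable_custom_all_reduce=False):
    from aphrodite.distributed import parallel_state
    # The flag must be set before model-parallel init, which now constructs
    # one CustomAllreduce (and one pynccl communicator) per TP group.
    parallel_state.set_custom_all_reduce(not disable_custom_all_reduce)
    parallel_state.init_distributed_environment(world_size, rank,
                                                distributed_init_method,
                                                local_rank)
    parallel_state.initialize_model_parallel(tp_size, pp_size)
    # None when custom all-reduce is disabled or unsupported on this setup.
    return parallel_state.get_tp_ca_communicator()
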
+ 4 - 11
aphrodite/task_handler/model_runner.py

@@ -20,8 +20,7 @@ from aphrodite.common.utils import (CudaMemoryProfiler,
                                     is_pin_memory_available,
                                     make_tensor_with_pad)
 from aphrodite.distributed import broadcast_tensor_dict
-from aphrodite.distributed.communication_op import graph_capture_mode
-from aphrodite.distributed.device_communicators import custom_all_reduce
+from aphrodite.distributed.communication_op import graph_capture, graph_mode
 from aphrodite.distributed.parallel_state import \
     get_tensor_model_parallel_world_size
 from aphrodite.lora.layers import LoRAMapping
@@ -950,13 +949,7 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE: There are 3 backends for all-reduce: custom all-reduce
-        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
-        # either custom all-reduce kernel or pynccl. When not using CUDA
-        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
-        # We always prioritize using custom all-reduce kernel but fall back
-        # to PyTorch or pynccl if it is disabled or not supported.
-        with custom_all_reduce.capture():
+        with graph_capture():
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
@@ -1048,7 +1041,7 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with graph_capture_mode():
+        with graph_mode():
             self.model(
                 input_ids,
                 positions,
@@ -1063,7 +1056,7 @@ class CUDAGraphRunner:
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with graph_capture_mode():
+            with graph_mode():
                 hidden_states = self.model(
                     input_ids,
                     positions,

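As a rough illustration of how the two renamed context managers nest during capture (simplified from `CUDAGraphRunner`; `run_model` is a placeholder for the real forward pass):

import torch

from aphrodite.distributed.communication_op import graph_capture, graph_mode

def capture_sketch(run_model, memory_pool=None):
    # One outer graph_capture() for the whole capture phase; the custom
    # all-reduce communicator registers its graph buffers when it exits.
    with graph_capture():
        # Warm-up run outside the graph, still routed through a
        # graph-safe all-reduce backend.
        with graph_mode():
            run_model()
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=memory_pool):
            with graph_mode():
                run_model()
        torch.cuda.synchronize()
    return graph
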
+ 3 - 7
aphrodite/task_handler/worker.py

@@ -14,9 +14,8 @@ from aphrodite.common.sequence import (ExecuteModelRequest, PoolerOutput,
                                        SamplerOutput)
 from aphrodite.distributed import (broadcast_tensor_dict,
                                    ensure_model_parallel_initialized,
-                                   init_distributed_environment)
-from aphrodite.distributed.device_communicators.custom_all_reduce import \
-    init_custom_ar
+                                   init_distributed_environment,
+                                   set_custom_all_reduce)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling import set_random_seed
 from aphrodite.task_handler.cache_engine import CacheEngine
@@ -304,16 +303,13 @@ def init_worker_distributed_environment(
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
-    # Initialize a custom fast all-reduce implementation.
-    if not parallel_config.disable_custom_all_reduce:
-        init_custom_ar()
-
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.