Browse Source

chore: init nccl using the gloo backend

AlpinDale 8 months ago
parent
commit
4c746d8baa

+ 78 - 66
aphrodite/distributed/device_communicators/pynccl.py

@@ -23,11 +23,15 @@ import ctypes
 import datetime
 import logging
 import os
+from typing import Optional, Union
 
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
+
+from aphrodite.distributed.parallel_state import (get_cpu_world_group,
+                                                  get_local_rank)
 
 logger = logging.getLogger(__name__)
 
@@ -63,6 +67,15 @@ except Exception as e:
 
 ncclResult_t = ctypes.c_int
 
+_c_ncclGetErrorString = nccl.ncclGetErrorString
+_c_ncclGetErrorString.restype = ctypes.c_char_p
+_c_ncclGetErrorString.argtypes = [ncclResult_t]
+
+def NCCL_CHECK(result: ncclResult_t) -> None:
+    if result != 0:
+        error_str = _c_ncclGetErrorString(result).decode("utf-8")
+        raise RuntimeError(f"NCCL error: {error_str}")
+
 # equivalent to c declaration:
 # ncclResult_t  ncclGetVersion(int *version);
 _c_ncclGetVersion = nccl.ncclGetVersion
@@ -72,8 +85,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
 
 def ncclGetVersion() -> str:
     version = ctypes.c_int()
-    result = _c_ncclGetVersion(ctypes.byref(version))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
     # something like 21903 --> "2.19.3"
     version_str = str(version.value)
     major = version_str[0].lstrip("0")
@@ -95,8 +107,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
 
 def ncclGetUniqueId() -> NcclUniqueId:
     unique_id = NcclUniqueId()
-    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
     return unique_id
 
 
@@ -110,10 +121,8 @@ _c_ncclCommInitRank.restype = ctypes.c_int
 _c_ncclCommInitRank.argtypes = [
     ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
 ]
-
-
-# enums
-class ncclDataType_t(ctypes.c_int):
+ncclDataType_t = ctypes.c_int
+class ncclDataTypeEnum:
     ncclInt8 = 0
     ncclChar = 0
     ncclUint8 = 1
@@ -130,9 +139,8 @@ class ncclDataType_t(ctypes.c_int):
     ncclDouble = 8
     ncclBfloat16 = 9
     ncclNumTypes = 10
-
     @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
+    def from_torch(cls, dtype: torch.dtype) -> int:
         if dtype == torch.int8:
             return cls.ncclInt8
         if dtype == torch.uint8:
@@ -150,18 +158,16 @@ class ncclDataType_t(ctypes.c_int):
         if dtype == torch.bfloat16:
             return cls.ncclBfloat16
         raise ValueError(f"Unsupported dtype: {dtype}")
-
-
-class ncclRedOp_t(ctypes.c_int):
+ncclRedOp_t = ctypes.c_int
+class ncclRedOpTypeEnum:
     ncclSum = 0
     ncclProd = 1
     ncclMax = 2
     ncclMin = 3
     ncclAvg = 4
     ncclNumOps = 5
-
     @classmethod
-    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
+    def from_torch(cls, op: ReduceOp) -> int:
         if op == ReduceOp.SUM:
             return cls.ncclSum
         if op == ReduceOp.PRODUCT:
@@ -173,8 +179,6 @@ class ncclRedOp_t(ctypes.c_int):
         if op == ReduceOp.AVG:
             return cls.ncclAvg
         raise ValueError(f"Unsupported op: {op}")
-
-
 # equivalent to c declaration:
 # ncclResult_t  ncclAllReduce(
 #   const void* sendbuff, void* recvbuff, size_t count,
@@ -184,10 +188,9 @@ class ncclRedOp_t(ctypes.c_int):
 _c_ncclAllReduce = nccl.ncclAllReduce
 _c_ncclAllReduce.restype = ctypes.c_int
 _c_ncclAllReduce.argtypes = [
-    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
-    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
+    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
+    ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
 ]
-
 # equivalent to c declaration:
 # ncclResult_t  ncclCommDestroy(ncclComm_t comm);
 _c_ncclCommDestroy = nccl.ncclCommDestroy
@@ -199,66 +202,75 @@ class NCCLCommunicator:
 
     def __init__(
         self,
-        backend=None,
-        init_method=None,
-        timeout=datetime.timedelta(seconds=10),
-        world_size: int = -1,
-        rank: int = -1,
-        store=None,
-        group_name: str = "",
-        pg_options=None,
-        local_rank: int = -1,
+        group: Optional[ProcessGroup] = None,
+        device: Optional[Union[int, str, torch.device]] = None,
     ):
-        if not dist.is_initialized():
-            backend = backend or "nccl"
-            assert backend == 'nccl', (
-                "only use nccl backend for starting the NCCL communicator")
-            dist.init_process_group(backend=backend,
-                                    init_method=init_method,
-                                    timeout=timeout,
-                                    world_size=world_size,
-                                    rank=rank,
-                                    store=store,
-                                    group_name=group_name,
-                                    pg_options=pg_options)
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        if local_rank == -1:
-            local_rank = self.rank
-        self.local_rank = local_rank
-        # don't use these args, as they can be -1
-        # use `self.rank`, `self.local_rank` and `self.world_size` instead
-        del world_size, rank, local_rank
-        torch.cuda.set_device(self.local_rank)
+        assert dist.is_initialized()
+        group = get_cpu_world_group() if group is None else group
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "NCCLCommunicator should be attached to a non-NCCL group.")
+        self.group = group
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(group)
         if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-            self.local_rank)
-        dist.broadcast(tensor, src=0)
-        byte_list = tensor.cpu().tolist()
+        tensor = torch.ByteTensor(list(self.unique_id.internal))
+        dist.broadcast(tensor, src=0, group=group)
+        byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                     self.unique_id, self.rank)
-        assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        # result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
+        #                              self.unique_id, self.rank)
+        # assert result == 0
+        # self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now the `device` object is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        current_device = torch.cuda.current_device()
+        try:
+            torch.cuda.set_device(self.device)
+            NCCL_CHECK(_c_ncclCommInitRank(ctypes.byref(self.comm),
+                                           self.world_size, self.unique_id,
+                                           self.rank))
+            self.stream = torch.cuda.Stream()
+        finally:
+            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        # nccl communicator created on a specific device will only work
+        # on tensors on the same device, otherwise it'll cause
+        # illegal memory access
+        assert tensor.device == self.device, (
+            f"tensor.device={tensor.device} should be {self.device}")
         if stream is None:
             stream = self.stream
-        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                                  ctypes.c_void_p(tensor.data_ptr()),
-                                  tensor.numel(),
-                                  ncclDataType_t.from_torch(tensor.dtype),
-                                  ncclRedOp_t.from_torch(op), self.comm,
-                                  ctypes.c_void_p(stream.cuda_stream))
-        assert result == 0
+        # result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
+        #                           ctypes.c_void_p(tensor.data_ptr()),
+        #                           tensor.numel(),
+        #                           ncclDataType_t.from_torch(tensor.dtype),
+        #                           ncclRedOp_t.from_torch(op), self.comm,
+        #                           ctypes.c_void_p(stream.cuda_stream))
+        NCCL_CHECK(
+            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
+                             ctypes.c_void_p(tensor.data_ptr()),
+                             tensor.numel(),
+                             ncclDataTypeEnum.from_torch(tensor.dtype),
+                             ncclRedOpTypeEnum.from_torch(op), self.comm,
+                             ctypes.c_void_p(stream.cuda_stream))
+        )
 
     def __del__(self):
         # `dist` module might have been already destroyed

+ 4 - 12
aphrodite/distributed/device_communicators/pynccl_utils.py

@@ -3,13 +3,11 @@ from typing import Optional
 
 import torch
 from loguru import logger
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
 
 try:
     from aphrodite.distributed.device_communicators.pynccl import (
-        NCCLCommunicator,
-        ncclGetVersion,
-    )
+        NCCLCommunicator, ncclGetVersion)
 except Exception as e:
     # in non-NVIDIA environments, we can't import the nccl module
     # e.g. when running on machines with AMD GPUs
@@ -35,17 +33,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
         pass
 
 
-def init_process_group(world_size: int,
-                       rank: int,
-                       init_method: str,
-                       local_rank: int = -1) -> None:
+def init_process_group(group: Optional[ProcessGroup] = None) -> None:
     assert not is_initialized()
     global comm
     logger.info(f"Aphrodite is using nccl=={ncclGetVersion()}")
-    comm = NCCLCommunicator(init_method=init_method,
-                            world_size=world_size,
-                            local_rank=local_rank,
-                            rank=rank)
+    comm = NCCLCommunicator(group=group)
 
 
 def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:

+ 4 - 0
aphrodite/distributed/parallel_state.py

@@ -5,6 +5,7 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
 import contextlib
+import os
 from typing import Optional
 
 import torch
@@ -71,6 +72,9 @@ def init_distributed_environment(
         ranks = list(range(torch.distributed.get_world_size()))
         _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
                                                        backend="gloo")
+        if local_rank == -1 and distributed_init_method == "env://":
+            # Get local rank from environment variable.
+            local_rank = int(os.environ.get["LOCAL_RANK"])
         global _LOCAL_RANK
         _LOCAL_RANK = local_rank
 

+ 1 - 6
aphrodite/task_handler/worker.py

@@ -299,12 +299,7 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE: We don't initialize pynccl process group when world size
         # is 1.
-        pynccl_utils.init_process_group(
-            world_size=parallel_config.world_size,
-            local_rank=local_rank,
-            rank=rank,
-            init_method=distributed_init_method,
-        )
+        pynccl_utils.init_process_group()
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)