@@ -23,11 +23,15 @@ import ctypes
 import datetime
 import logging
 import os
+from typing import Optional, Union
 
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
+
+from aphrodite.distributed.parallel_state import (get_cpu_world_group,
+                                                  get_local_rank)
 
 logger = logging.getLogger(__name__)
 
@@ -63,6 +67,15 @@ except Exception as e:
 
 ncclResult_t = ctypes.c_int
 
+_c_ncclGetErrorString = nccl.ncclGetErrorString
+_c_ncclGetErrorString.restype = ctypes.c_char_p
+_c_ncclGetErrorString.argtypes = [ncclResult_t]
+
+def NCCL_CHECK(result: ncclResult_t) -> None:
+    if result != 0:
+        error_str = _c_ncclGetErrorString(result).decode("utf-8")
+        raise RuntimeError(f"NCCL error: {error_str}")
+
 # equivalent to c declaration:
 # ncclResult_t ncclGetVersion(int *version);
 _c_ncclGetVersion = nccl.ncclGetVersion
@@ -72,8 +85,7 @@ _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
 
 def ncclGetVersion() -> str:
     version = ctypes.c_int()
-    result = _c_ncclGetVersion(ctypes.byref(version))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
     # something like 21903 --> "2.19.3"
     version_str = str(version.value)
     major = version_str[0].lstrip("0")
@@ -95,8 +107,7 @@ _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
 
 def ncclGetUniqueId() -> NcclUniqueId:
     unique_id = NcclUniqueId()
-    result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
-    assert result == 0
+    NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
     return unique_id
 
 
@@ -110,10 +121,8 @@ _c_ncclCommInitRank.restype = ctypes.c_int
 _c_ncclCommInitRank.argtypes = [
     ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
 ]
-
-
-# enums
-class ncclDataType_t(ctypes.c_int):
+ncclDataType_t = ctypes.c_int
+class ncclDataTypeEnum:
     ncclInt8 = 0
     ncclChar = 0
     ncclUint8 = 1
@@ -130,9 +139,8 @@ class ncclDataType_t(ctypes.c_int):
     ncclDouble = 8
     ncclBfloat16 = 9
     ncclNumTypes = 10
-
     @classmethod
-    def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
+    def from_torch(cls, dtype: torch.dtype) -> int:
         if dtype == torch.int8:
             return cls.ncclInt8
         if dtype == torch.uint8:
@@ -150,18 +158,16 @@ class ncclDataType_t(ctypes.c_int):
         if dtype == torch.bfloat16:
             return cls.ncclBfloat16
         raise ValueError(f"Unsupported dtype: {dtype}")
-
-
-class ncclRedOp_t(ctypes.c_int):
+ncclRedOp_t = ctypes.c_int
+class ncclRedOpTypeEnum:
     ncclSum = 0
     ncclProd = 1
     ncclMax = 2
     ncclMin = 3
     ncclAvg = 4
     ncclNumOps = 5
-
     @classmethod
-    def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
+    def from_torch(cls, op: ReduceOp) -> int:
         if op == ReduceOp.SUM:
             return cls.ncclSum
         if op == ReduceOp.PRODUCT:
@@ -173,8 +179,6 @@ class ncclRedOp_t(ctypes.c_int):
         if op == ReduceOp.AVG:
             return cls.ncclAvg
         raise ValueError(f"Unsupported op: {op}")
-
-
 # equivalent to c declaration:
 # ncclResult_t ncclAllReduce(
 #   const void* sendbuff, void* recvbuff, size_t count,
@@ -184,10 +188,9 @@ class ncclRedOp_t(ctypes.c_int):
 _c_ncclAllReduce = nccl.ncclAllReduce
 _c_ncclAllReduce.restype = ctypes.c_int
 _c_ncclAllReduce.argtypes = [
-    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
-    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
+    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
+    ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
 ]
-
 # equivalent to c declaration:
 # ncclResult_t ncclCommDestroy(ncclComm_t comm);
 _c_ncclCommDestroy = nccl.ncclCommDestroy
@@ -199,66 +202,75 @@ class NCCLCommunicator:
 
     def __init__(
         self,
-        backend=None,
-        init_method=None,
-        timeout=datetime.timedelta(seconds=10),
-        world_size: int = -1,
-        rank: int = -1,
-        store=None,
-        group_name: str = "",
-        pg_options=None,
-        local_rank: int = -1,
+        group: Optional[ProcessGroup] = None,
+        device: Optional[Union[int, str, torch.device]] = None,
     ):
-        if not dist.is_initialized():
-            backend = backend or "nccl"
-            assert backend == 'nccl', (
-                "only use nccl backend for starting the NCCL communicator")
-            dist.init_process_group(backend=backend,
-                                    init_method=init_method,
-                                    timeout=timeout,
-                                    world_size=world_size,
-                                    rank=rank,
-                                    store=store,
-                                    group_name=group_name,
-                                    pg_options=pg_options)
-        self.rank = dist.get_rank()
-        self.world_size = dist.get_world_size()
-        if local_rank == -1:
-            local_rank = self.rank
-        self.local_rank = local_rank
-        # don't use these args, as they can be -1
-        # use `self.rank`, `self.local_rank` and `self.world_size` instead
-        del world_size, rank, local_rank
-        torch.cuda.set_device(self.local_rank)
+        assert dist.is_initialized()
+        group = get_cpu_world_group() if group is None else group
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "NCCLCommunicator should be attached to a non-NCCL group.")
+        self.group = group
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(group)
         if self.rank == 0:
            self.unique_id = ncclGetUniqueId()
        else:
            self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
-            self.local_rank)
-        dist.broadcast(tensor, src=0)
-        byte_list = tensor.cpu().tolist()
+        tensor = torch.ByteTensor(list(self.unique_id.internal))
+        dist.broadcast(tensor, src=0, group=group)
+        byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
-                                     self.unique_id, self.rank)
-        assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        # result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
+        #                              self.unique_id, self.rank)
+        # assert result == 0
+        # self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now the `device` object is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+        current_device = torch.cuda.current_device()
+        try:
+            torch.cuda.set_device(self.device)
+            NCCL_CHECK(_c_ncclCommInitRank(ctypes.byref(self.comm),
+                                           self.world_size, self.unique_id,
+                                           self.rank))
+            self.stream = torch.cuda.Stream()
+        finally:
+            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None):
+        # nccl communicator created on a specific device will only work
+        # on tensors on the same device, otherwise it'll cause
+        # illegal memory access
+        assert tensor.device == self.device, (
+            f"tensor.device={tensor.device} should be {self.device}")
         if stream is None:
             stream = self.stream
-        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
-                                  ctypes.c_void_p(tensor.data_ptr()),
-                                  tensor.numel(),
-                                  ncclDataType_t.from_torch(tensor.dtype),
-                                  ncclRedOp_t.from_torch(op), self.comm,
-                                  ctypes.c_void_p(stream.cuda_stream))
-        assert result == 0
+        # result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
+        #                           ctypes.c_void_p(tensor.data_ptr()),
+        #                           tensor.numel(),
+        #                           ncclDataType_t.from_torch(tensor.dtype),
+        #                           ncclRedOp_t.from_torch(op), self.comm,
+        #                           ctypes.c_void_p(stream.cuda_stream))
+        NCCL_CHECK(
+            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                             ctypes.c_void_p(tensor.data_ptr()),
                             tensor.numel(),
                             ncclDataTypeEnum.from_torch(tensor.dtype),
                             ncclRedOpTypeEnum.from_torch(op), self.comm,
                             ctypes.c_void_p(stream.cuda_stream))
+        )
 
     def __del__(self):
         # `dist` module might have been already destroyed
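For reference, a minimal usage sketch of the refactored NCCLCommunicator (not part of the patch). It assumes one CUDA device per rank, an explicit gloo-backed torch.distributed group for the out-of-band broadcast of the NCCL unique id, and an import path of aphrodite.distributed.device_communicators.pynccl, which is an assumption for illustration only; the rendezvous address is likewise a placeholder.

import torch
import torch.distributed as dist

# assumed import path; adjust to wherever pynccl lives in the tree
from aphrodite.distributed.device_communicators.pynccl import NCCLCommunicator


def run(rank: int, world_size: int) -> None:
    # gloo (non-NCCL) group carries the broadcast of the NCCL unique id
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",  # example address
                            rank=rank,
                            world_size=world_size)
    # new constructor signature: attach to a non-NCCL group, pick a CUDA device
    comm = NCCLCommunicator(group=dist.group.WORLD, device=rank)
    t = torch.ones(16, dtype=torch.float32, device=comm.device)
    # run the all-reduce on the current stream and wait for it to finish
    comm.all_reduce(t, stream=torch.cuda.current_stream())
    torch.cuda.synchronize()
    assert torch.all(t == world_size).item()

Each rank would call run() in its own process (for example via torch.multiprocessing.spawn), with one GPU per rank, so that ncclCommInitRank is invoked collectively by all ranks.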