123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287 |
- # This file is a pure Python wrapper for the NCCL library.
- # The main purpose is to use NCCL combined with CUDA graph.
- # Before writing this script, we tried the following approach:
- # 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
- # often gets stuck when initializing the NCCL communicator.
- # 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
- # contains many other potential cuda APIs, that are not allowed during
- # capturing the CUDA graph. For further details, please check
- # https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
- #
- # Another rejected idea is to write a C/C++ binding for NCCL. It is usually
- # doable, but we often encounter issues related with nccl versions, and need
- # to switch between different versions of NCCL. See
- # https://github.com/NVIDIA/nccl/issues/1234 for more details.
- # A C/C++ binding is not flexible enough to handle this. It requires
- # recompilation of the code every time we want to switch between different
- # versions. This current implementation, with a **pure** Python wrapper, is
- # more flexible. We can easily switch between different versions of NCCL by
- # changing the environment variable `APHRODITE_NCCL_SO_PATH`, or the `so_file`
- # variable in the code.
- import ctypes
- import logging
- import os
- from typing import Optional, Union
- # ===================== import region =====================
- import torch
- import torch.distributed as dist
- from torch.distributed import ProcessGroup, ReduceOp
- from aphrodite.distributed.parallel_state import (get_cpu_world_group,
- get_local_rank)
logger = logging.getLogger(__name__)

# The env variable APHRODITE_NCCL_SO_PATH overrides the library chosen below.
so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")

# manually load the nccl library
if so_file:
    logger.info("Loading nccl from env variable APHRODITE_NCCL_SO_PATH=%s",
                so_file)
else:
    # pick the default library name for the backend torch was built with
    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.debug("Loading nccl from library %s", so_file)

try:
    nccl = ctypes.CDLL(so_file)
except OSError:
    # ctypes.CDLL reports a missing/unloadable shared library as OSError;
    # log a hint and re-raise the original exception unchanged.
    logger.error(
        "Failed to load NCCL library from %s. "
        "It is expected if you are not running on NVIDIA/AMD GPUs. "
        "Otherwise please set the environment variable APHRODITE_NCCL_SO_PATH"
        " to point to the correct nccl library path. You can install nccl"
        " with `conda install nccl` or `pip install nvidia-nccl-cu12`",
        so_file)
    raise
# === export types and functions from nccl to Python ===
# The C declarations mirrored below come from
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

# Every NCCL entry point returns an `ncclResult_t` status code (a C int).
ncclResult_t = ctypes.c_int

# equivalent to c declaration:
# const char* ncclGetErrorString(ncclResult_t result);
_c_ncclGetErrorString = nccl.ncclGetErrorString
_c_ncclGetErrorString.argtypes = [ncclResult_t]
_c_ncclGetErrorString.restype = ctypes.c_char_p
def NCCL_CHECK(result: ncclResult_t) -> None:
    """Turn a non-zero NCCL status code into a Python exception.

    Args:
        result: status returned by an NCCL call; 0 means success.

    Raises:
        RuntimeError: carrying NCCL's own description of ``result``.
    """
    if result == 0:
        return
    message = _c_ncclGetErrorString(result).decode("utf-8")
    raise RuntimeError(f"NCCL error: {message}")
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
# The integer version code is written through the `int*` out-parameter.
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
_c_ncclGetVersion.restype = ncclResult_t
def _version_code_to_str(code: int) -> str:
    """Convert an integer NCCL version code to a "major.minor.patch" string.

    NCCL >= 2.9 encodes versions as ``major * 10000 + minor * 100 + patch``
    (e.g. 21903 -> "2.19.3"); older releases used
    ``major * 1000 + minor * 100 + patch`` (e.g. 2708 -> "2.7.8").
    Arithmetic decoding keeps zero components (21900 -> "2.19.0").
    """
    # The smallest new-style code (2.9.0 -> 20900) exceeds every old-style
    # code (at most 2.8.x -> 28xx), so a single threshold tells them apart.
    if code >= 10000:
        major, rest = divmod(code, 10000)
    else:
        major, rest = divmod(code, 1000)
    minor, patch = divmod(rest, 100)
    return f"{major}.{minor}.{patch}"


def ncclGetVersion() -> str:
    """Return the version of the loaded NCCL library as "major.minor.patch".

    Raises:
        RuntimeError: if ``ncclGetVersion`` reports an error.
    """
    version = ctypes.c_int()
    NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
    return _version_code_to_str(version.value)
class NcclUniqueId(ctypes.Structure):
    """ctypes mirror of NCCL's ``ncclUniqueId``: an opaque 128-byte blob.

    One rank obtains it from ``ncclGetUniqueId``; the same bytes are then
    shared with the other ranks before ``ncclCommInitRank`` is called.
    """
    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
# NCCL fills the caller-provided struct in place.
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
_c_ncclGetUniqueId.restype = ncclResult_t
def ncclGetUniqueId() -> NcclUniqueId:
    """Ask NCCL for a fresh unique id used to bootstrap a communicator.

    Raises:
        RuntimeError: if ``ncclGetUniqueId`` reports an error.
    """
    uid = NcclUniqueId()
    status = _c_ncclGetUniqueId(ctypes.byref(uid))
    NCCL_CHECK(status)
    return uid
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
#   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# ncclComm_t is itself a pointer type, so `comm` is a pointer to a pointer;
# the unique id struct is passed by value.
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.argtypes = [
    ctypes.POINTER(ctypes.c_void_p),  # comm (out)
    ctypes.c_int,  # nranks
    NcclUniqueId,  # commId
    ctypes.c_int,  # rank
]
_c_ncclCommInitRank.restype = ncclResult_t
ncclDataType_t = ctypes.c_int


class ncclDataTypeEnum:
    """Integer values of NCCL's ``ncclDataType_t`` enum.

    Aliases (e.g. ``ncclInt8``/``ncclChar``) share the same value, exactly
    as in the C header.
    """
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Map a ``torch.dtype`` to its ``ncclDataType_t`` value.

        Raises:
            ValueError: if ``dtype`` has no counterpart in this table.
        """
        table = (
            (torch.int8, cls.ncclInt8),
            (torch.uint8, cls.ncclUint8),
            (torch.int32, cls.ncclInt32),
            (torch.int64, cls.ncclInt64),
            (torch.float16, cls.ncclFloat16),
            (torch.float32, cls.ncclFloat32),
            (torch.float64, cls.ncclFloat64),
            (torch.bfloat16, cls.ncclBfloat16),
        )
        for torch_dtype, nccl_dtype in table:
            if dtype == torch_dtype:
                return nccl_dtype
        raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int


class ncclRedOpTypeEnum:
    """Integer values of NCCL's ``ncclRedOp_t`` enum."""
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Map a ``torch.distributed.ReduceOp`` to its ``ncclRedOp_t`` value.

        Raises:
            ValueError: if ``op`` has no counterpart in this table.
        """
        table = (
            (ReduceOp.SUM, cls.ncclSum),
            (ReduceOp.PRODUCT, cls.ncclProd),
            (ReduceOp.MAX, cls.ncclMax),
            (ReduceOp.MIN, cls.ncclMin),
            (ReduceOp.AVG, cls.ncclAvg),
        )
        for torch_op, nccl_op in table:
            if op == torch_op:
                return nccl_op
        raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
#   const void* sendbuff, void* recvbuff, size_t count,
#   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
#   cudaStream_t stream);
# note that ncclComm_t and cudaStream_t are pointer types, so the last two
# arguments are declared as void pointers.
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ncclResult_t
# Order must follow the C signature: datatype comes before op.
_c_ncclAllReduce.argtypes = [
    ctypes.c_void_p,  # sendbuff
    ctypes.c_void_p,  # recvbuff
    ctypes.c_size_t,  # count
    ncclDataType_t,  # datatype
    ncclRedOp_t,  # op
    ctypes.c_void_p,  # comm
    ctypes.c_void_p,  # stream
]
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
# ncclComm_t is a pointer type, passed here as a void pointer.
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
_c_ncclCommDestroy.restype = ncclResult_t
class NCCLCommunicator:
    """Owns one NCCL communicator and issues collectives directly via ctypes.

    The NCCL unique id is created on rank 0 of ``group`` and broadcast over
    ``group``, which must be a non-NCCL (e.g. CPU) process group. After that,
    NCCL calls made by this object do not go through ``torch.distributed``.
    """

    def __init__(
        self,
        group: Optional[ProcessGroup] = None,
        device: Optional[Union[int, str, torch.device]] = None,
    ):
        """Create an NCCL communicator spanning every rank of ``group``.

        Args:
            group: process group used to exchange the NCCL unique id; it must
                not use the NCCL backend. Defaults to ``get_cpu_world_group()``.
            device: device the communicator is bound to. An ``int`` is taken
                as ``cuda:<int>``, a ``str`` is passed to ``torch.device``.
                Defaults to ``cuda:<get_local_rank()>``.

        Raises:
            RuntimeError: if ``ncclCommInitRank`` fails.
        """
        assert dist.is_initialized()
        group = get_cpu_world_group() if group is None else group
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "NCCLCommunicator should be attached to a non-NCCL group.")
        self.group = group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)
        if self.rank == 0:
            self.unique_id = ncclGetUniqueId()
        else:
            self.unique_id = NcclUniqueId()
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        # `src` of dist.broadcast is a *global* rank. Group rank 0 is only
        # global rank 0 for the world group, so translate it explicitly.
        ranks = dist.get_process_group_ranks(group)
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
        # Stays NULL until ncclCommInitRank succeeds; __del__ relies on that.
        self.comm = ctypes.c_void_p()
        if device is None:
            local_rank = get_local_rank()
            device = torch.device(f"cuda:{local_rank}")
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now the `device` object is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # the communicator and the default stream are created while `device`
        # is the current CUDA device
        with torch.cuda.device(device):
            NCCL_CHECK(
                _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                    self.unique_id, self.rank))
            self.stream = torch.cuda.Stream()

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream: Optional[torch.cuda.Stream] = None):
        """All-reduce ``tensor`` in place across the communicator's ranks.

        Args:
            tensor: contiguous tensor on ``self.device``; it is used as both
                the send and the receive buffer.
            op: reduction to apply.
            stream: CUDA stream to enqueue the collective on; defaults to the
                stream created in ``__init__``.

        Raises:
            AssertionError: if ``tensor`` is on another device or is not
                contiguous.
            ValueError: if the dtype or ``op`` has no NCCL equivalent.
            RuntimeError: if ``ncclAllReduce`` fails.
        """
        # nccl communicator created on a specific device will only work
        # on tensors on the same device, otherwise it'll cause
        # illegal memory access
        assert tensor.device == self.device, (
            f"tensor.device={tensor.device} should be {self.device}")
        # NCCL reduces `numel()` consecutive elements starting at
        # `data_ptr()`; for a strided view that is the wrong memory.
        assert tensor.is_contiguous(), (
            "tensor must be contiguous for NCCL all_reduce")
        if stream is None:
            stream = self.stream
        NCCL_CHECK(
            _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                             ctypes.c_void_p(tensor.data_ptr()),
                             tensor.numel(),
                             ncclDataTypeEnum.from_torch(tensor.dtype),
                             ncclRedOpTypeEnum.from_torch(op), self.comm,
                             ctypes.c_void_p(stream.cuda_stream)))

    def __del__(self):
        """Destroy the NCCL communicator owned by this object, if any.

        The process group used during construction is not owned by this
        object and is left untouched.
        """
        # `comm` is missing (or NULL) when __init__ failed before (or during)
        # ncclCommInitRank; there is nothing to destroy in that case.
        comm = getattr(self, "comm", None)
        if comm is None or not comm.value:
            return
        # function might have been already destroyed at interpreter shutdown
        if _c_ncclCommDestroy is not None:
            _c_ncclCommDestroy(comm)
|