- # This file is a pure Python wrapper for the NCCL library.
- # The main purpose is to use NCCL together with CUDA graphs.
- # Before writing this script, we tried the following approaches:
- # 1. We tried `cupy`: it calls NCCL correctly, but `cupy` itself often gets
- # stuck when initializing the NCCL communicator.
- # 2. We tried `torch.distributed`, but `torch.distributed.all_reduce` may
- # invoke other CUDA APIs that are not allowed while capturing a CUDA graph.
- # For further details, please check
- # https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
- #
- # Another rejected idea is to write a C/C++ binding for NCCL. It is usually
- # doable, but we often run into issues related to NCCL versions and need to
- # switch between different versions of NCCL. See
- # https://github.com/NVIDIA/nccl/issues/1234 for more details.
- # A C/C++ binding is not flexible enough to handle this: it requires
- # recompiling the code every time we want to switch versions. The current
- # implementation, a **pure** Python wrapper, is more flexible. We can easily
- # switch between different versions of NCCL by changing the environment
- # variable `APHRODITE_NCCL_SO_PATH`, or the `so_file` variable in the code.
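- # For example (the paths below are illustrative only; adjust them to your
- # installation):
- #   APHRODITE_NCCL_SO_PATH=/usr/lib/x86_64-linux-gnu/libnccl.so.2 python app.py
- # or, in code:
- #   lib = NCCLLibrary(so_file="/usr/local/lib/libnccl.so.2.19.3")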
- import ctypes
- import platform
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional
- import torch
- from loguru import logger
- from torch.distributed import ReduceOp
- from aphrodite.common.utils import find_nccl_library
- # === export types and functions from nccl to Python ===
- # for the original nccl definition, please check
- # https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
- ncclResult_t = ctypes.c_int
- ncclComm_t = ctypes.c_void_p
- class ncclUniqueId(ctypes.Structure):
- _fields_ = [("internal", ctypes.c_byte * 128)]
- cudaStream_t = ctypes.c_void_p
- buffer_type = ctypes.c_void_p
- ncclDataType_t = ctypes.c_int
- class ncclDataTypeEnum:
- ncclInt8 = 0
- ncclChar = 0
- ncclUint8 = 1
- ncclInt32 = 2
- ncclInt = 2
- ncclUint32 = 3
- ncclInt64 = 4
- ncclUint64 = 5
- ncclFloat16 = 6
- ncclHalf = 6
- ncclFloat32 = 7
- ncclFloat = 7
- ncclFloat64 = 8
- ncclDouble = 8
- ncclBfloat16 = 9
- ncclNumTypes = 10
- @classmethod
- def from_torch(cls, dtype: torch.dtype) -> int:
- if dtype == torch.int8:
- return cls.ncclInt8
- if dtype == torch.uint8:
- return cls.ncclUint8
- if dtype == torch.int32:
- return cls.ncclInt32
- if dtype == torch.int64:
- return cls.ncclInt64
- if dtype == torch.float16:
- return cls.ncclFloat16
- if dtype == torch.float32:
- return cls.ncclFloat32
- if dtype == torch.float64:
- return cls.ncclFloat64
- if dtype == torch.bfloat16:
- return cls.ncclBfloat16
- raise ValueError(f"Unsupported dtype: {dtype}")
- ncclRedOp_t = ctypes.c_int
- class ncclRedOpTypeEnum:
- ncclSum = 0
- ncclProd = 1
- ncclMax = 2
- ncclMin = 3
- ncclAvg = 4
- ncclNumOps = 5
- @classmethod
- def from_torch(cls, op: ReduceOp) -> int:
- if op == ReduceOp.SUM:
- return cls.ncclSum
- if op == ReduceOp.PRODUCT:
- return cls.ncclProd
- if op == ReduceOp.MAX:
- return cls.ncclMax
- if op == ReduceOp.MIN:
- return cls.ncclMin
- if op == ReduceOp.AVG:
- return cls.ncclAvg
- raise ValueError(f"Unsupported op: {op}")
- @dataclass
- class Function:
- name: str
- restype: Any
- argtypes: List[Any]
- class NCCLLibrary:
- exported_functions = [
- # const char* ncclGetErrorString(ncclResult_t result)
- Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
- # ncclResult_t ncclGetVersion(int *version);
- Function("ncclGetVersion", ncclResult_t,
- [ctypes.POINTER(ctypes.c_int)]),
- # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
- Function("ncclGetUniqueId", ncclResult_t,
- [ctypes.POINTER(ncclUniqueId)]),
- # ncclResult_t ncclCommInitRank(
- # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
- # note that ncclComm_t is a pointer type, so the first argument
- # is a pointer to a pointer
- Function("ncclCommInitRank", ncclResult_t, [
- ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
- ctypes.c_int
- ]),
- # ncclResult_t ncclAllReduce(
- # const void* sendbuff, void* recvbuff, size_t count,
- # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
- # cudaStream_t stream);
- # note that cudaStream_t is a pointer type, so the last argument
- # is a pointer
- Function("ncclAllReduce", ncclResult_t, [
- buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
- ncclRedOp_t, ncclComm_t, cudaStream_t
- ]),
- # ncclResult_t ncclSend(
- # const void* sendbuff, size_t count, ncclDataType_t datatype,
- # int dest, ncclComm_t comm, cudaStream_t stream);
- Function("ncclSend", ncclResult_t, [
- buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
- ncclComm_t, cudaStream_t
- ]),
- # ncclResult_t ncclRecv(
- # void* recvbuff, size_t count, ncclDataType_t datatype,
- # int src, ncclComm_t comm, cudaStream_t stream);
- Function("ncclRecv", ncclResult_t, [
- buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
- ncclComm_t, cudaStream_t
- ]),
- # Be cautious! This is a collective call: it will block until all
- # processes in the communicator have called this function.
- # Because Python object destruction can happen in a random order,
- # it is better not to call it at all.
- # ncclResult_t ncclCommDestroy(ncclComm_t comm);
- Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
- ]
- # class attribute to store the mapping from the path to the library
- # to avoid loading the same library multiple times
- path_to_library_cache: Dict[str, Any] = {}
- # class attribute to store the mapping from library path
- # to the corresponding dictionary of ctypes function handles
- path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
- def __init__(self, so_file: Optional[str] = None):
- so_file = so_file or find_nccl_library()
- try:
- if so_file not in NCCLLibrary.path_to_library_cache:
- lib = ctypes.CDLL(so_file)
- NCCLLibrary.path_to_library_cache[so_file] = lib
- self.lib = NCCLLibrary.path_to_library_cache[so_file]
- except Exception as e:
- logger.error(
- f"Failed to load NCCL library from {so_file}. "
- "This is expected if you are not running on NVIDIA/AMD GPUs. "
- "Otherwise, the nccl library might not exist, be corrupted, "
- "or not support the current platform "
- f"{platform.platform()}. If you already have the library, "
- "please set the environment variable APHRODITE_NCCL_SO_PATH "
- "to point to the correct nccl library path.")
- raise e
- if so_file not in NCCLLibrary.path_to_dict_mapping:
- _funcs = {}
- for func in NCCLLibrary.exported_functions:
- f = getattr(self.lib, func.name)
- f.restype = func.restype
- f.argtypes = func.argtypes
- _funcs[func.name] = f
- NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
- self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
- def ncclGetErrorString(self, result: ncclResult_t) -> str:
- return self._funcs["ncclGetErrorString"](result).decode("utf-8")
- def NCCL_CHECK(self, result: ncclResult_t) -> None:
- if result != 0:
- error_str = self.ncclGetErrorString(result)
- raise RuntimeError(f"NCCL error: {error_str}")
- def ncclGetVersion(self) -> str:
- version = ctypes.c_int()
- self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
- version_str = str(version.value)
- # something like 21903 --> "2.19.3"
- major = version_str[0].lstrip("0")
- minor = version_str[1:3].lstrip("0")
- patch = version_str[3:].lstrip("0")
- return f"{major}.{minor}.{patch}"
- def ncclGetUniqueId(self) -> ncclUniqueId:
- unique_id = ncclUniqueId()
- self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
- ctypes.byref(unique_id)))
- return unique_id
- def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
- rank: int) -> ncclComm_t:
- comm = ncclComm_t()
- self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
- world_size, unique_id,
- rank))
- return comm
- def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
- count: int, datatype: int, op: int, comm: ncclComm_t,
- stream: cudaStream_t) -> None:
- # `datatype` should technically be `ncclDataType_t` and `op` should be
- # `ncclRedOp_t`; both are aliases of `ctypes.c_int`.
- # When we pass a Python int to the function, ctypes converts it to
- # `ctypes.c_int` automatically.
- self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
- datatype, op, comm,
- stream))
- def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
- dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
- self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
- dest, comm, stream))
- def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
- src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
- self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
- comm, stream))
- def ncclCommDestroy(self, comm: ncclComm_t) -> None:
- self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
- __all__ = [
- "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
- "ncclComm_t", "cudaStream_t", "buffer_type"
- ]
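- # A minimal, self-contained sketch (illustrative only): load the library and
- # print its version. This assumes a working NCCL installation discoverable by
- # `find_nccl_library` or pointed to by APHRODITE_NCCL_SO_PATH; it is guarded
- # so that importing this module never executes it.
- if __name__ == "__main__":
-     nccl = NCCLLibrary()
-     print(f"NCCL version: {nccl.ncclGetVersion()}")
-     unique_id = nccl.ncclGetUniqueId()
-     print(f"Generated ncclUniqueId with {len(bytes(unique_id.internal))} bytes")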