pynccl.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # This file is a pure Python wrapper for the NCCL library.
  2. # The main purpose is to use NCCL combined with CUDA graph.
  3. # Before writing this script, we tried the following approaches:
  4. # 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
  5. # often gets stuck when initializing the NCCL communicator.
  6. # 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
  7. # contains many other potential cuda APIs, that are not allowed during
  8. # capturing the CUDA graph. For further details, please check
  9. # https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
  10. #
  11. # Another rejected idea is to write a C/C++ binding for NCCL. It is usually
  12. # doable, but we often encounter issues related to NCCL versions, and need
  13. # to switch between different versions of NCCL. See
  14. # https://github.com/NVIDIA/nccl/issues/1234 for more details.
  15. # A C/C++ binding is not flexible enough to handle this. It requires
  16. # recompilation of the code every time we want to switch between different
  17. # versions. This current implementation, with a **pure** Python wrapper, is
  18. # more flexible. We can easily switch between different versions of NCCL by
  19. # changing the environment variable `APHRODITE_NCCL_SO_PATH`, or the `so_file`
  20. # variable in the code.
  21. import ctypes
  22. import logging
  23. import os
  24. from typing import Optional, Union
  25. # ===================== import region =====================
  26. import torch
  27. import torch.distributed as dist
  28. from torch.distributed import ProcessGroup, ReduceOp
  29. from aphrodite.distributed.parallel_state import (get_cpu_world_group,
  30. get_local_rank)
  31. logger = logging.getLogger(__name__)
  32. so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
  33. # manually load the nccl library
  34. if so_file:
  35. logger.info(
  36. f"Loading nccl from env variable APHRODITE_NCCL_SO_PATH={so_file}")
  37. else:
  38. if torch.version.cuda is not None:
  39. so_file = "libnccl.so.2"
  40. elif torch.version.hip is not None:
  41. so_file = "librccl.so.1"
  42. else:
  43. raise ValueError("NCCL only supports CUDA and ROCm backends.")
  44. logger.debug(f"Loading nccl from library {so_file}")
  45. try:
  46. nccl = ctypes.CDLL(so_file)
  47. except Exception as e:
  48. logger.error(
  49. f"Failed to load NCCL library from {so_file} ."
  50. "It is expected if you are not running on NVIDIA/AMD GPUs."
  51. "Otherwise please set the environment variable APHRODITE_NCCL_SO_PATH"
  52. " to point to the correct nccl library path. You can install nccl"
  53. " with `conda install nccl` or `pip install nvidia-nccl-cu12`")
  54. raise e
  55. # === export types and functions from nccl to Python ===
  56. # for the original nccl definition, please check
  57. # https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
  58. ncclResult_t = ctypes.c_int
  59. _c_ncclGetErrorString = nccl.ncclGetErrorString
  60. _c_ncclGetErrorString.restype = ctypes.c_char_p
  61. _c_ncclGetErrorString.argtypes = [ncclResult_t]
  62. def NCCL_CHECK(result: ncclResult_t) -> None:
  63. if result != 0:
  64. error_str = _c_ncclGetErrorString(result).decode("utf-8")
  65. raise RuntimeError(f"NCCL error: {error_str}")
  66. # equivalent to c declaration:
  67. # ncclResult_t ncclGetVersion(int *version);
  68. _c_ncclGetVersion = nccl.ncclGetVersion
  69. _c_ncclGetVersion.restype = ctypes.c_int
  70. _c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
  71. def ncclGetVersion() -> str:
  72. version = ctypes.c_int()
  73. NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
  74. # something like 21903 --> "2.19.3"
  75. version_str = str(version.value)
  76. major = version_str[0].lstrip("0")
  77. minor = version_str[1:3].lstrip("0")
  78. patch = version_str[3:].lstrip("0")
  79. return f"{major}.{minor}.{patch}"
  80. class NcclUniqueId(ctypes.Structure):
  81. _fields_ = [("internal", ctypes.c_byte * 128)]
  82. # equivalent to c declaration:
  83. # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
  84. _c_ncclGetUniqueId = nccl.ncclGetUniqueId
  85. _c_ncclGetUniqueId.restype = ctypes.c_int
  86. _c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
  87. def ncclGetUniqueId() -> NcclUniqueId:
  88. unique_id = NcclUniqueId()
  89. NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
  90. return unique_id
  91. # equivalent to c declaration:
  92. # ncclResult_t ncclCommInitRank(
  93. # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
  94. # note that ncclComm_t is a pointer type, so the first argument
  95. # is a pointer to a pointer
  96. _c_ncclCommInitRank = nccl.ncclCommInitRank
  97. _c_ncclCommInitRank.restype = ctypes.c_int
  98. _c_ncclCommInitRank.argtypes = [
  99. ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
  100. ]
  101. ncclDataType_t = ctypes.c_int
  102. class ncclDataTypeEnum:
  103. ncclInt8 = 0
  104. ncclChar = 0
  105. ncclUint8 = 1
  106. ncclInt32 = 2
  107. ncclInt = 2
  108. ncclUint32 = 3
  109. ncclInt64 = 4
  110. ncclUint64 = 5
  111. ncclFloat16 = 6
  112. ncclHalf = 6
  113. ncclFloat32 = 7
  114. ncclFloat = 7
  115. ncclFloat64 = 8
  116. ncclDouble = 8
  117. ncclBfloat16 = 9
  118. ncclNumTypes = 10
  119. @classmethod
  120. def from_torch(cls, dtype: torch.dtype) -> int:
  121. if dtype == torch.int8:
  122. return cls.ncclInt8
  123. if dtype == torch.uint8:
  124. return cls.ncclUint8
  125. if dtype == torch.int32:
  126. return cls.ncclInt32
  127. if dtype == torch.int64:
  128. return cls.ncclInt64
  129. if dtype == torch.float16:
  130. return cls.ncclFloat16
  131. if dtype == torch.float32:
  132. return cls.ncclFloat32
  133. if dtype == torch.float64:
  134. return cls.ncclFloat64
  135. if dtype == torch.bfloat16:
  136. return cls.ncclBfloat16
  137. raise ValueError(f"Unsupported dtype: {dtype}")
  138. ncclRedOp_t = ctypes.c_int
  139. class ncclRedOpTypeEnum:
  140. ncclSum = 0
  141. ncclProd = 1
  142. ncclMax = 2
  143. ncclMin = 3
  144. ncclAvg = 4
  145. ncclNumOps = 5
  146. @classmethod
  147. def from_torch(cls, op: ReduceOp) -> int:
  148. if op == ReduceOp.SUM:
  149. return cls.ncclSum
  150. if op == ReduceOp.PRODUCT:
  151. return cls.ncclProd
  152. if op == ReduceOp.MAX:
  153. return cls.ncclMax
  154. if op == ReduceOp.MIN:
  155. return cls.ncclMin
  156. if op == ReduceOp.AVG:
  157. return cls.ncclAvg
  158. raise ValueError(f"Unsupported op: {op}")
  159. # equivalent to c declaration:
  160. # ncclResult_t ncclAllReduce(
  161. # const void* sendbuff, void* recvbuff, size_t count,
  162. # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
  163. # udaStream_t stream);
  164. # note that cudaStream_t is a pointer type, so the last argument is a pointer
  165. _c_ncclAllReduce = nccl.ncclAllReduce
  166. _c_ncclAllReduce.restype = ctypes.c_int
  167. _c_ncclAllReduce.argtypes = [
  168. ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
  169. ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
  170. ]
  171. # equivalent to c declaration:
  172. # ncclResult_t ncclCommDestroy(ncclComm_t comm);
  173. _c_ncclCommDestroy = nccl.ncclCommDestroy
  174. _c_ncclCommDestroy.restype = ctypes.c_int
  175. _c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
  176. class NCCLCommunicator:
  177. def __init__(
  178. self,
  179. group: Optional[ProcessGroup] = None,
  180. device: Optional[Union[int, str, torch.device]] = None,
  181. ):
  182. assert dist.is_initialized()
  183. group = get_cpu_world_group() if group is None else group
  184. assert dist.get_backend(group) != dist.Backend.NCCL, (
  185. "NCCLCommunicator should be attached to a non-NCCL group.")
  186. self.group = group
  187. self.rank = dist.get_rank(group)
  188. self.world_size = dist.get_world_size(group)
  189. if self.rank == 0:
  190. self.unique_id = ncclGetUniqueId()
  191. else:
  192. self.unique_id = NcclUniqueId()
  193. tensor = torch.ByteTensor(list(self.unique_id.internal))
  194. dist.broadcast(tensor, src=0, group=group)
  195. byte_list = tensor.tolist()
  196. for i, byte in enumerate(byte_list):
  197. self.unique_id.internal[i] = byte
  198. self.comm = ctypes.c_void_p()
  199. # result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
  200. # self.unique_id, self.rank)
  201. # assert result == 0
  202. # self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
  203. if device is None:
  204. local_rank = get_local_rank()
  205. device = torch.device(f"cuda:{local_rank}")
  206. elif isinstance(device, int):
  207. device = torch.device(f"cuda:{device}")
  208. elif isinstance(device, str):
  209. device = torch.device(device)
  210. # now the `device` object is a `torch.device` object
  211. assert isinstance(device, torch.device)
  212. self.device = device
  213. with torch.cuda.device(device):
  214. NCCL_CHECK(
  215. _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
  216. self.unique_id, self.rank))
  217. self.stream = torch.cuda.Stream()
  218. def all_reduce(self,
  219. tensor: torch.Tensor,
  220. op: ReduceOp = ReduceOp.SUM,
  221. stream=None):
  222. # nccl communicator created on a specific device will only work
  223. # on tensors on the same device, otherwise it'll cause
  224. # illegal memory access
  225. assert tensor.device == self.device, (
  226. f"tensor.device={tensor.device} should be {self.device}")
  227. if stream is None:
  228. stream = self.stream
  229. # result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
  230. # ctypes.c_void_p(tensor.data_ptr()),
  231. # tensor.numel(),
  232. # ncclDataType_t.from_torch(tensor.dtype),
  233. # ncclRedOp_t.from_torch(op), self.comm,
  234. # ctypes.c_void_p(stream.cuda_stream))
  235. NCCL_CHECK(
  236. _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
  237. ctypes.c_void_p(tensor.data_ptr()),
  238. tensor.numel(),
  239. ncclDataTypeEnum.from_torch(tensor.dtype),
  240. ncclRedOpTypeEnum.from_torch(op), self.comm,
  241. ctypes.c_void_p(stream.cuda_stream)))
  242. def __del__(self):
  243. # `dist` module might have been already destroyed
  244. if hasattr(dist, 'destroy_process_group'):
  245. dist.destroy_process_group()
  246. # function might have been already destroyed
  247. if _c_ncclCommDestroy is not None:
  248. _c_ncclCommDestroy(self.comm)