# pynccl.py
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `APHRODITE_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import datetime
import logging
import os

# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

logger = logging.getLogger(__name__)

# Path to the NCCL/RCCL shared library; empty string means "auto-detect".
so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")

# manually load the nccl library
if so_file:
    logger.info(
        f"Loading nccl from env variable APHRODITE_NCCL_SO_PATH={so_file}")
else:
    # No explicit path given: pick the conventional soname for the
    # backend torch was built against (NCCL for CUDA, RCCL for ROCm).
    if torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.debug(f"Loading nccl from library {so_file}")

# NOTE: the library is loaded at import time; importing this module on a
# machine without NCCL/RCCL raises here.
try:
    nccl = ctypes.CDLL(so_file)
except Exception as e:
    logger.error(
        f"Failed to load NCCL library from {so_file} ."
        "It is expected if you are not running on NVIDIA/AMD GPUs."
        "Otherwise please set the environment variable APHRODITE_NCCL_SO_PATH"
        " to point to the correct nccl library path. You can install nccl"
        " with `conda install nccl` or `pip install nvidia-nccl-cu12`")
    raise e
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

# ncclResult_t is a plain C enum (int); 0 means ncclSuccess.
ncclResult_t = ctypes.c_int

# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ctypes.c_int
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
  62. def ncclGetVersion() -> str:
  63. version = ctypes.c_int()
  64. result = _c_ncclGetVersion(ctypes.byref(version))
  65. assert result == 0
  66. # something like 21903 --> "2.19.3"
  67. version_str = str(version.value)
  68. major = version_str[0].lstrip("0")
  69. minor = version_str[1:3].lstrip("0")
  70. patch = version_str[3:].lstrip("0")
  71. return f"{major}.{minor}.{patch}"
class NcclUniqueId(ctypes.Structure):
    """Mirror of the C ``ncclUniqueId`` struct: an opaque 128-byte blob
    that rank 0 creates and broadcasts so all ranks join the same
    communicator."""
    _fields_ = [("internal", ctypes.c_byte * 128)]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.restype = ctypes.c_int
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
  79. def ncclGetUniqueId() -> NcclUniqueId:
  80. unique_id = NcclUniqueId()
  81. result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
  82. assert result == 0
  83. return unique_id
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
#   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.restype = ctypes.c_int
_c_ncclCommInitRank.argtypes = [
    ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
  94. # enums
  95. class ncclDataType_t(ctypes.c_int):
  96. ncclInt8 = 0
  97. ncclChar = 0
  98. ncclUint8 = 1
  99. ncclInt32 = 2
  100. ncclInt = 2
  101. ncclUint32 = 3
  102. ncclInt64 = 4
  103. ncclUint64 = 5
  104. ncclFloat16 = 6
  105. ncclHalf = 6
  106. ncclFloat32 = 7
  107. ncclFloat = 7
  108. ncclFloat64 = 8
  109. ncclDouble = 8
  110. ncclBfloat16 = 9
  111. ncclNumTypes = 10
  112. @classmethod
  113. def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
  114. if dtype == torch.int8:
  115. return cls.ncclInt8
  116. if dtype == torch.uint8:
  117. return cls.ncclUint8
  118. if dtype == torch.int32:
  119. return cls.ncclInt32
  120. if dtype == torch.int64:
  121. return cls.ncclInt64
  122. if dtype == torch.float16:
  123. return cls.ncclFloat16
  124. if dtype == torch.float32:
  125. return cls.ncclFloat32
  126. if dtype == torch.float64:
  127. return cls.ncclFloat64
  128. if dtype == torch.bfloat16:
  129. return cls.ncclBfloat16
  130. raise ValueError(f"Unsupported dtype: {dtype}")
  131. class ncclRedOp_t(ctypes.c_int):
  132. ncclSum = 0
  133. ncclProd = 1
  134. ncclMax = 2
  135. ncclMin = 3
  136. ncclAvg = 4
  137. ncclNumOps = 5
  138. @classmethod
  139. def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
  140. if op == ReduceOp.SUM:
  141. return cls.ncclSum
  142. if op == ReduceOp.PRODUCT:
  143. return cls.ncclProd
  144. if op == ReduceOp.MAX:
  145. return cls.ncclMax
  146. if op == ReduceOp.MIN:
  147. return cls.ncclMin
  148. if op == ReduceOp.AVG:
  149. return cls.ncclAvg
  150. raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
#   const void* sendbuff, void* recvbuff, size_t count,
#   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
#   cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
    ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
]

# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.restype = ctypes.c_int
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
class NCCLCommunicator:
    """NCCL communicator driven directly through ctypes.

    Collectives issue only the NCCL call on a caller-chosen CUDA stream,
    which keeps them usable during CUDA graph capture (unlike
    ``torch.distributed.all_reduce``). ``torch.distributed`` is still used
    once at construction time, to broadcast the NCCL unique id from rank 0
    to the other ranks.
    """

    def __init__(
        self,
        backend=None,
        init_method=None,
        timeout=datetime.timedelta(seconds=10),
        world_size: int = -1,
        rank: int = -1,
        store=None,
        group_name: str = "",
        pg_options=None,
        local_rank: int = -1,
    ):
        # The parameters mirror dist.init_process_group so callers can pass
        # the same arguments through; they are only consumed when the
        # default process group has not been initialized yet.
        if not dist.is_initialized():
            backend = backend or "nccl"
            assert backend == 'nccl', (
                "only use nccl backend for starting the NCCL communicator")
            dist.init_process_group(backend=backend,
                                    init_method=init_method,
                                    timeout=timeout,
                                    world_size=world_size,
                                    rank=rank,
                                    store=store,
                                    group_name=group_name,
                                    pg_options=pg_options)
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        if local_rank == -1:
            local_rank = self.rank
        self.local_rank = local_rank
        # don't use these args, as they can be -1
        # use `self.rank`, `self.local_rank` and `self.world_size` instead
        del world_size, rank, local_rank
        torch.cuda.set_device(self.local_rank)
        # Rank 0 creates the NCCL unique id; all other ranks start from a
        # zeroed struct and receive the real bytes via the broadcast below.
        if self.rank == 0:
            self.unique_id = ncclGetUniqueId()
        else:
            self.unique_id = NcclUniqueId()
        # NOTE(review): `internal` holds ctypes.c_byte (signed) values, and
        # ByteTensor is uint8 -- this relies on torch accepting negative
        # entries here; confirm on the deployed torch version.
        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
            self.local_rank)
        dist.broadcast(tensor, src=0)
        # Copy the broadcast bytes back into every rank's unique_id struct.
        byte_list = tensor.cpu().tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
        # Opaque ncclComm_t handle, filled in by ncclCommInitRank.
        self.comm = ctypes.c_void_p()
        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                     self.unique_id, self.rank)
        assert result == 0
        # Dedicated stream for collectives issued by this communicator.
        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        """In-place all-reduce of ``tensor`` across all ranks.

        The same device pointer is passed as both sendbuff and recvbuff, so
        the result overwrites ``tensor``. The call is asynchronous on
        ``stream`` (defaults to this communicator's private stream); the
        caller is responsible for any needed stream synchronization.
        """
        if stream is None:
            stream = self.stream
        result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
                                  ctypes.c_void_p(tensor.data_ptr()),
                                  tensor.numel(),
                                  ncclDataType_t.from_torch(tensor.dtype),
                                  ncclRedOp_t.from_torch(op), self.comm,
                                  ctypes.c_void_p(stream.cuda_stream))
        assert result == 0

    def __del__(self):
        # NOTE(review): this tears down the *default* process group even
        # though this object may not have created it -- verify no other
        # code still needs torch.distributed after this object is collected.
        # `dist` module might have been already destroyed
        if hasattr(dist, 'destroy_process_group'):
            dist.destroy_process_group()
        # function might have been already destroyed
        if _c_ncclCommDestroy is not None:
            _c_ncclCommDestroy(self.comm)