# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approaches:
# 1. We tried `cupy`; it calls NCCL correctly, but `cupy` itself often gets
#    stuck when initializing the NCCL communicator.
# 2. We tried `torch.distributed`, but `torch.distributed.all_reduce` calls
#    many other CUDA APIs that are not allowed while capturing a CUDA graph.
#    For further details, please check
#    https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this: it requires
# recompilation every time we want to switch between different versions.
# The current implementation, with a **pure** Python wrapper, is more
# flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `APHRODITE_NCCL_SO_PATH`, or the
# `so_file` variable in the code.
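#
# For example, to pin a specific NCCL build (the path below is purely
# illustrative):
#   APHRODITE_NCCL_SO_PATH=/usr/lib/x86_64-linux-gnu/libnccl.so.2 python main.py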
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from loguru import logger
from torch.distributed import ReduceOp

from aphrodite.common.utils import find_nccl_library

# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p


class ncclUniqueId(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 128)]


cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

ncclDataType_t = ctypes.c_int


class ncclDataTypeEnum:
    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        if dtype == torch.int8:
            return cls.ncclInt8
        if dtype == torch.uint8:
            return cls.ncclUint8
        if dtype == torch.int32:
            return cls.ncclInt32
        if dtype == torch.int64:
            return cls.ncclInt64
        if dtype == torch.float16:
            return cls.ncclFloat16
        if dtype == torch.float32:
            return cls.ncclFloat32
        if dtype == torch.float64:
            return cls.ncclFloat64
        if dtype == torch.bfloat16:
            return cls.ncclBfloat16
        raise ValueError(f"Unsupported dtype: {dtype}")
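
# For example, ncclDataTypeEnum.from_torch(torch.float16) returns
# ncclFloat16 (== ncclHalf == 6), mapping the torch dtype to the NCCL enum.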

ncclRedOp_t = ctypes.c_int


class ncclRedOpTypeEnum:
    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        if op == ReduceOp.SUM:
            return cls.ncclSum
        if op == ReduceOp.PRODUCT:
            return cls.ncclProd
        if op == ReduceOp.MAX:
            return cls.ncclMax
        if op == ReduceOp.MIN:
            return cls.ncclMin
        if op == ReduceOp.AVG:
            return cls.ncclAvg
        raise ValueError(f"Unsupported op: {op}")

@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


class NCCLLibrary:
    exported_functions = [
        # const char* ncclGetErrorString(ncclResult_t result)
        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
        # ncclResult_t ncclGetVersion(int *version);
        Function("ncclGetVersion", ncclResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
        Function("ncclGetUniqueId", ncclResult_t,
                 [ctypes.POINTER(ncclUniqueId)]),
        # ncclResult_t ncclCommInitRank(
        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
        # note that ncclComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("ncclCommInitRank", ncclResult_t, [
            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
            ctypes.c_int
        ]),
        # ncclResult_t ncclAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllReduce", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclSend(
        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
        #   int dest, ncclComm_t comm, cudaStream_t stream);
        Function("ncclSend", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclRecv(
        #   void* recvbuff, size_t count, ncclDataType_t datatype,
        #   int src, ncclComm_t comm, cudaStream_t stream);
        Function("ncclRecv", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),
        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # ncclResult_t ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
    ]

    # class attribute storing the mapping from library path to the loaded
    # library handle, to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute storing the mapping from library path to the
    # corresponding dictionary of exported functions
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        so_file = so_file or find_nccl_library()

        try:
            if so_file not in NCCLLibrary.path_to_dict_mapping:
                lib = ctypes.CDLL(so_file)
                NCCLLibrary.path_to_library_cache[so_file] = lib
            self.lib = NCCLLibrary.path_to_library_cache[so_file]
        except Exception as e:
            logger.error(
                f"Failed to load NCCL library from {so_file}. "
                "This is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the nccl library might not exist, be corrupted, "
                "or not support the current platform "
                f"{platform.platform()}. If you already have the library, "
                "please set the environment variable APHRODITE_NCCL_SO_PATH "
                "to point to the correct nccl library path.")
            raise e

        if so_file not in NCCLLibrary.path_to_dict_mapping:
            _funcs = {}
            for func in NCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]

    def ncclGetErrorString(self, result: ncclResult_t) -> str:
        return self._funcs["ncclGetErrorString"](result).decode("utf-8")

    def NCCL_CHECK(self, result: ncclResult_t) -> None:
        if result != 0:
            error_str = self.ncclGetErrorString(result)
            raise RuntimeError(f"NCCL error: {error_str}")

    def ncclGetVersion(self) -> str:
        version = ctypes.c_int()
        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
        version_str = str(version.value)
        # something like 21903 --> "2.19.3"
        major = version_str[0].lstrip("0")
        minor = version_str[1:3].lstrip("0")
        patch = version_str[3:].lstrip("0")
        return f"{major}.{minor}.{patch}"

    def ncclGetUniqueId(self) -> ncclUniqueId:
        unique_id = ncclUniqueId()
        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                         rank: int) -> ncclComm_t:
        comm = ncclComm_t()
        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
                                                        world_size, unique_id,
                                                        rank))
        return comm
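
    # Illustrative rendezvous (a sketch, not part of the wrapper itself):
    # rank 0 creates the id with ncclGetUniqueId() and ships its 128 raw
    # bytes to the other ranks out of band (e.g. via torch.distributed);
    # every rank then calls ncclCommInitRank(world_size, unique_id, rank)
    # with the same id.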

    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        # `datatype` should really be `ncclDataType_t` and `op` should be
        # `ncclRedOp_t`, but both are aliases of `ctypes.c_int`, and ctypes
        # converts a plain Python int to `ctypes.c_int` automatically.
        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff,
                                                     count, datatype, op,
                                                     comm, stream))
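
    # Illustrative call pattern (a sketch; `lib` is an NCCLLibrary, `comm`
    # came from ncclCommInitRank, and `tensor` is a hypothetical contiguous
    # CUDA tensor). An in-place all-reduce on the current torch stream
    # would look like:
    #   lib.ncclAllReduce(
    #       buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),
    #       tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype),
    #       ncclRedOpTypeEnum.from_torch(ReduceOp.SUM), comm,
    #       cudaStream_t(torch.cuda.current_stream().cuda_stream))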

    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
                                                dest, comm, stream))

    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype,
                                                src, comm, stream))

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

__all__ = [
    "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
    "ncclComm_t", "cudaStream_t", "buffer_type"
]
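
# A minimal sanity check (a sketch): it assumes a local libnccl is
# discoverable via `find_nccl_library` or `APHRODITE_NCCL_SO_PATH`, and only
# loads the library and queries its version, so no GPU communicator is
# created.
if __name__ == "__main__":
    nccl = NCCLLibrary()
    print(f"Loaded NCCL version {nccl.ncclGetVersion()}")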