# custom_all_reduce.py

import os
from contextlib import contextmanager
from typing import Any, List, Optional, Union

import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed import ProcessGroup

from aphrodite.distributed.parallel_state import (
    get_local_rank, get_tensor_model_parallel_cpu_group)

try:
    import pynvml

    from aphrodite._C import custom_ar

    @contextmanager
    def _nvml():
        try:
            pynvml.nvmlInit()
            yield
        finally:
            pynvml.nvmlShutdown()

except ImportError:
    # For AMD GPUs
    custom_ar = None
    pynvml = None

    @contextmanager
    def _nvml():
        try:
            yield
        finally:
            pass


@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
    """
    Query whether the set of GPUs is fully connected by NVLink (1 hop).
    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
    so it works on real physical device ids.
    """
    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
    for i, handle in enumerate(handles):
        for j, peer_handle in enumerate(handles):
            if i < j:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError as error:
                    logger.opt(exception=error).error(
                        "NVLink detection failed. This is normal if your"
                        " machine has no NVLink equipped.")
                    return False
    return True
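
# Illustrative note (added; not part of the original logic): `_is_full_nvlink`
# takes *physical* device ids, while torch device indices are logical under
# `CUDA_VISIBLE_DEVICES`. A minimal sketch of the translation performed in
# `CustomAllreduce.__init__` below, assuming CUDA_VISIBLE_DEVICES="2,3":
#
#     visible = os.environ["CUDA_VISIBLE_DEVICES"].split(",")  # ["2", "3"]
#     physical_ids = [int(v) for v in visible]                 # [2, 3]
#     _is_full_nvlink(physical_ids)  # queries physical devices 2 and 3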


def _can_p2p(rank: int, world_size: int) -> bool:
    from aphrodite.distributed.utils import gpu_p2p_access_check
    for i in range(world_size):
        if i == rank:
            continue
        if not gpu_p2p_access_check(rank, i):
            return False
    return True


class CustomAllreduce:

    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]

    # max_size: max supported allreduce size
    def __init__(self,
                 group: Optional[ProcessGroup] = None,
                 device: Optional[Union[int, str, torch.device]] = None,
                 max_size=8192 * 1024) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bound to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True

        if custom_ar is None:
            # disable because of missing custom allreduce library
            # e.g. in a non-cuda environment
            return

        group = group or get_tensor_model_parallel_cpu_group()
        self.group = group

        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "CustomAllreduce should be attached to a non-NCCL group.")

        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize custom allreduce for single GPU case.
            return

        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
            logger.warning(
                "Custom allreduce is disabled due to an unsupported world"
                " size: {}. Supported world sizes: {}. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly.",
                world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
            return

        if device is None:
            local_rank = get_local_rank()
            device = torch.device(f"cuda:{local_rank}")
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(torch.cuda.device_count()))
        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id],
                              dtype=torch.int,
                              device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu")
            for _ in range(world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]

        # test nvlink first, this will filter out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
        full_nvlink = _is_full_nvlink(physical_device_ids)
        if world_size > 2 and not full_nvlink:
            logger.warning(
                "Custom allreduce is disabled because it's not supported on"
                " more than two PCIe-only GPUs. To silence this warning, "
                "specify disable_custom_all_reduce=True explicitly.")
            return

        # test P2P capability, this checks software/cudaruntime support
        # this is expensive to compute the first time, so the result is cached
        if not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly.")
            return

        self.disabled = False
        # Buffer memory is owned by this Python class and passed to C++.
        # The metadata consists of two parts: synchronization metadata
        # (256 bytes) and a temporary buffer for storing intermediate
        # allreduce results.
        self.meta = torch.zeros(custom_ar.meta_size() + max_size,
                                dtype=torch.uint8,
                                device=self.device)
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed.
        self.buffer = torch.empty(max_size,
                                  dtype=torch.uint8,
                                  device=self.device)
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has a size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples; the largest model seen so far
        # needs fewer than 10000 registered tuples.
        self.rank_data = torch.empty(8 * 1024 * 1024,
                                     dtype=torch.uint8,
                                     device=self.device)
        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        handles, offsets = self._get_ipc_meta(self.meta)
        self.full_nvlink = full_nvlink
        self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
                                             handles, offsets, rank,
                                             self.full_nvlink)
        self.register_buffer(self.buffer)

    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()
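
    # Illustrative usage sketch (added; `cpu_group`, `local_rank`, `graph` and
    # `x` are hypothetical names, not from this file). `capture()` is meant to
    # wrap CUDA graph capture so that the graph's buffer addresses are
    # registered once on exit:
    #
    #     ca = CustomAllreduce(group=cpu_group, device=local_rank)
    #     with ca.capture():
    #         ca.custom_all_reduce(x)      # warmup: returns empty_like(x)
    #         with torch.cuda.graph(graph):
    #             ca.custom_all_reduce(x)  # capturing: uses all_reduce_reg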

    def _get_ipc_meta(self, inp: torch.Tensor):
        data = inp.untyped_storage()._share_cuda_()
        shard_data = (
            data[1],  # ipc handle to base ptr
            data[3],  # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        # Note: don't use `[[None]] * self.world_size` here
        # because it will create a list of the same reference
        all_data: List[Optional[Any]] = [[None]
                                         for i in range(self.world_size)]
        all_data[self.rank][0] = shard_data
        ranks = dist.get_process_group_ranks(group=self.group)
        ranks.sort()
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(all_data[i],
                                       src=rank,
                                       group=self.group,
                                       device="cpu")
        # we cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.

        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0][0])  # type: ignore
            offsets.append(all_data[i][0][1])  # type: ignore
        return handles, offsets
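
    # Descriptive note (added): `handles[i]` is the CUDA IPC memory handle
    # exported by rank i for its shared allocation and `offsets[i]` is the
    # byte offset of the tensor within that allocation, as produced by
    # `_share_cuda_()` above; both lists are ordered by rank.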

    def register_buffer(self, inp: torch.Tensor):
        handles, offsets = self._get_ipc_meta(inp)
        custom_ar.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
        handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
        logger.info("Registering {} cuda graph addresses", len(offset))
        custom_ar.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
                                          self.full_nvlink)

    # all reduce, assuming inp tensor is IPC registered with register_buffer,
    # or, in the context of cuda graphs, register_graph_buffers
    def all_reduce_reg(self,
                       inp: torch.Tensor,
                       out: Optional[torch.Tensor] = None):
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_reg(self._ptr, inp, out)
        return out

    # all reduce, assuming inp tensor is NOT IPC registered
    def all_reduce_unreg(self,
                         inp: torch.Tensor,
                         out: Optional[torch.Tensor] = None):
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out
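
    # Note (added): `custom_all_reduce` below chooses between these two paths:
    # `all_reduce_reg` while capturing a CUDA graph (whose buffers get
    # registered at the end of `capture()`), and `all_reduce_unreg` in eager
    # mode, which first stages the input through the pre-registered
    # `self.buffer`.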

    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        # when custom allreduce is disabled, this will be None
        if self.disabled:
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                if self.should_custom_ar(input):
                    return self.all_reduce_reg(input)
            else:
                if self.should_custom_ar(input):
                    # during warmup runs, mimic the allocation pattern
                    # since custom allreduce is out-of-place
                    return torch.empty_like(input)
        else:
            # note: outside of cuda graph context, custom allreduce incurs
            # a cost of cudaMemcpy, which should be small (<=1% of overall
            # latency) compared to the performance gain of using custom
            # kernels
            if self.should_custom_ar(input):
                return self.all_reduce_unreg(input)
        return None

    def close(self):
        if not self.disabled and self._ptr:
            custom_ar.dispose(self._ptr)
            self._ptr = 0

    def __del__(self):
        self.close()
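

# Usage sketch (added for illustration; not part of the original module). It
# assumes a multi-GPU, single-node launch, a non-NCCL (e.g. gloo) process
# group, and that the `custom_ar` extension is built; `cpu_group` and
# `local_rank` are placeholder names:
#
#     ca = CustomAllreduce(group=cpu_group, device=local_rank)
#     x = torch.ones(1024, dtype=torch.float16, device=f"cuda:{local_rank}")
#     out = ca.custom_all_reduce(x)  # None if the custom path does not apply
#     if out is None:
#         dist.all_reduce(x)         # caller falls back to the regular path
#     ca.close()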