custom_all_reduce.py

from contextlib import contextmanager
from typing import Any, List, Optional, Union

import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed import ProcessGroup

import aphrodite.common.envs as envs
from aphrodite import _custom_ops as ops
from aphrodite.common.utils import cuda_device_count_stateless
from aphrodite.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)
from aphrodite.distributed.parallel_state import in_the_same_node_as
from aphrodite.platforms import current_platform

try:
    ops.meta_size()
    custom_ar = True
except Exception:
    # For AMD GPUs and CPUs
    custom_ar = False


def _can_p2p(rank: int, world_size: int) -> bool:
    for i in range(world_size):
        if i == rank:
            continue
        if not gpu_p2p_access_check(rank, i):
            return False
    return True


def is_weak_contiguous(inp: torch.Tensor):
    return inp.is_contiguous() or (inp.storage().nbytes() -
                                   inp.storage_offset() * inp.element_size()
                                   == inp.numel() * inp.element_size())
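
# Illustrative note (example added for clarity, not from the original file):
# "weak contiguous" also accepts non-contiguous views that still cover their
# entire storage, e.g. a transposed square matrix:
#
#     x = torch.empty(4, 4).t()
#     assert not x.is_contiguous()
#     assert is_weak_contiguous(x)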


class CustomAllreduce:

    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]

    # max_size: max supported allreduce size
    def __init__(self,
                 group: ProcessGroup,
                 device: Union[int, str, torch.device],
                 max_size=8192 * 1024) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bound to f"cuda:{local_rank}".
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True

        if not custom_ar:
            # disable because of missing custom allreduce library
            # e.g. in a non-cuda environment
            return

        self.group = group

        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "CustomAllreduce should be attached to a non-NCCL group.")

        if not all(in_the_same_node_as(group, source_rank=0)):
            # No need to initialize custom allreduce for multi-node case.
            logger.warning(
                "Custom allreduce is disabled because this process group"
                " spans across nodes.")
            return

        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize custom allreduce for single GPU case.
            return

        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
            if rank == 0:
                logger.warning(
                    "Custom allreduce is disabled due to an unsupported world"
                    f" size: {world_size}. Supported world sizes: "
                    f"{str(CustomAllreduce._SUPPORTED_WORLD_SIZES)}. To "
                    "silence this warning, specify disable_custom_all_reduce="
                    "True explicitly.")
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(cuda_device_count_stateless()))

        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id],
                              dtype=torch.int,
                              device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu")
            for _ in range(world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]

        # test NVLink first, as this filters out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
        assert current_platform.is_cuda()
        from aphrodite.platforms.cuda import CudaPlatform
        cuda_platform: CudaPlatform = current_platform
        full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
        if world_size > 2 and not full_nvlink:
            if rank == 0:
                logger.warning(
                    "Custom allreduce is disabled because it's not supported "
                    "on more than two PCIe-only GPUs. To silence this "
                    "warning, specify disable_custom_all_reduce=True "
                    "explicitly.")
            return

        # test P2P capability, which checks software/CUDA runtime support.
        # This is expensive to compute the first time, so the result is
        # cached afterwards.
        if not _can_p2p(rank, world_size):
            if rank == 0:
                logger.warning(
                    "Custom allreduce is disabled because your platform lacks "
                    "GPU P2P capability or P2P test failed. To silence this "
                    "warning, specify disable_custom_all_reduce=True "
                    "explicitly.")
            return

        self.disabled = False
        # Buffer memory is owned by this Python class and passed to C++.
        # The metadata consists of two parts: synchronization metadata
        # (256 bytes) and a temporary buffer for storing intermediate
        # allreduce results.
        self.meta = torch.zeros(ops.meta_size() + max_size,
                                dtype=torch.uint8,
                                device=self.device)
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed.
        self.buffer = torch.empty(max_size,
                                  dtype=torch.uint8,
                                  device=self.device)
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has a size of
        # 8*world_size bytes, where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples. The largest model I've seen
        # needs fewer than 10000 registered tuples.
        self.rank_data = torch.empty(8 * 1024 * 1024,
                                     dtype=torch.uint8,
                                     device=self.device)
        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        handles, offsets = self._get_ipc_meta(self.meta)
        self.full_nvlink = full_nvlink
        self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
                                       offsets, rank, self.full_nvlink)
        self.register_buffer(self.buffer)

    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()

    def _get_ipc_meta(self, inp: torch.Tensor):
        data = inp.untyped_storage()._share_cuda_()
        shard_data = (
            data[1],  # ipc handle to base ptr
            data[3],  # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        # Note: don't use `[[None]] * self.world_size` here, because that
        # would create `world_size` references to the same inner list
        # (see the illustrative note after this method).
        all_data: List[Optional[Any]] = [[None]
                                         for i in range(self.world_size)]
        all_data[self.rank][0] = shard_data

        ranks = dist.get_process_group_ranks(group=self.group)
        ranks.sort()
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(all_data[i],
                                       src=rank,
                                       group=self.group,
                                       device="cpu")
        # we cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.

        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0][0])  # type: ignore
            offsets.append(all_data[i][0][1])  # type: ignore
        return handles, offsets
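
    # Illustrative aliasing note (example added for clarity, not from the
    # original file): `[[None]] * n` repeats one inner list n times, while
    # the list comprehension above builds n distinct lists.
    #
    #     aliased = [[None]] * 4
    #     aliased[0][0] = "x"    # every element now reads ['x']
    #     distinct = [[None] for _ in range(4)]
    #     distinct[0][0] = "x"   # only the first inner list changes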

    def register_buffer(self, inp: torch.Tensor):
        handles, offsets = self._get_ipc_meta(inp)
        ops.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
        handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
        if self.rank == 0:
            logger.info(f"Registering {len(offset)} cuda graph addresses")
        ops.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        if self.disabled:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom allreduce requires the input byte size to be a multiple of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        # for 4 or more non-NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
        if self.world_size == 2 or self.full_nvlink:
            return inp_size < self.max_size
        return False

    # all reduce, assuming the inp tensor is IPC registered with
    # register_buffer, or, in the context of cuda graphs,
    # register_graph_buffers
    def all_reduce_reg(self,
                       inp: torch.Tensor,
                       out: Optional[torch.Tensor] = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_reg(self._ptr, inp, out)
        return out

    # all reduce, assuming the inp tensor is NOT IPC registered
    def all_reduce_unreg(self,
                         inp: torch.Tensor,
                         out: Optional[torch.Tensor] = None):
        if out is None:
            out = torch.empty_like(inp)
        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out

    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        # when custom allreduce is disabled, this returns None
        if self.disabled:
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                if self.should_custom_ar(input):
                    return self.all_reduce_reg(input)
            else:
                if self.should_custom_ar(input):
                    # during warm-up, mimic the allocation pattern
                    # since custom allreduce is out-of-place
                    return torch.empty_like(input)
        else:
            # note: outside of the cuda graph context,
            # custom allreduce incurs a cost of cudaMemcpy, which should
            # be small (<=1% of overall latency) compared to the performance
            # gains of using custom kernels
            if self.should_custom_ar(input):
                return self.all_reduce_unreg(input)
        return None

    def close(self):
        if not self.disabled and self._ptr:
            ops.dispose(self._ptr)
            self._ptr = 0

    def __del__(self):
        self.close()
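

# Usage sketch (illustrative example, not part of the original module): how a
# caller might wire CustomAllreduce into an already-initialized
# torch.distributed job. The gloo group, tensor shape, and fallback logic
# below are assumptions for illustration, not APIs defined in this file.
#
# def _example_all_reduce(rank: int, world_size: int) -> None:
#     # CustomAllreduce must be given a non-NCCL group; it only uses the
#     # group for CPU-side exchange of IPC handles and offsets.
#     cpu_group = dist.new_group(backend="gloo")
#     comm = CustomAllreduce(group=cpu_group, device=rank)
#
#     x = torch.ones(1024, 1024, dtype=torch.float16, device=f"cuda:{rank}")
#     out = comm.custom_all_reduce(x)
#     if out is None:
#         # Custom allreduce is disabled or the input does not qualify
#         # (see should_custom_ar); fall back to the regular collective.
#         dist.all_reduce(x)
#         out = x
#
#     comm.close()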