custom_all_reduce.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import os
  2. from contextlib import contextmanager
  3. from typing import Any, List, Optional, Union
  4. import torch
  5. import torch.distributed as dist
  6. from loguru import logger
  7. from torch.distributed import ProcessGroup
  8. from aphrodite import _custom_ops as ops
  9. from aphrodite.common.utils import cuda_device_count_stateless
  10. from aphrodite.distributed.device_communicators.custom_all_reduce_utils import (
  11. gpu_p2p_access_check)
  12. from aphrodite.distributed.parallel_state import in_the_same_node_as
  13. from aphrodite.platforms import current_platform
  14. try:
  15. ops.meta_size()
  16. custom_ar = True
  17. except Exception:
  18. # For AMD GPUs and CPUs
  19. custom_ar = False
  20. def _can_p2p(rank: int, world_size: int) -> bool:
  21. for i in range(world_size):
  22. if i == rank:
  23. continue
  24. if not gpu_p2p_access_check(rank, i):
  25. return False
  26. return True
  27. class CustomAllreduce:
  28. _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
  29. # max_size: max supported allreduce size
  30. def __init__(self,
  31. group: ProcessGroup,
  32. device: Union[int, str, torch.device],
  33. max_size=8192 * 1024) -> None:
  34. """
  35. Args:
  36. group: the process group to work on. If None, it will use the
  37. default process group.
  38. device: the device to bind the CustomAllreduce to. If None,
  39. it will be bind to f"cuda:{local_rank}".
  40. It is the caller's responsibility to make sure each communicator
  41. is bind to a unique device, and all communicators in this group
  42. are in the same node.
  43. """
  44. self._IS_CAPTURING = False
  45. self.disabled = True
  46. if not custom_ar:
  47. # disable because of missing custom allreduce library
  48. # e.g. in a non-cuda environment
  49. return
  50. self.group = group
  51. assert dist.get_backend(group) != dist.Backend.NCCL, (
  52. "CustomAllreduce should be attached to a non-NCCL group.")
  53. if not all(in_the_same_node_as(group, source_rank=0)):
  54. # No need to initialize custom allreduce for multi-node case.
  55. logger.warning(
  56. "Custom allreduce is disabled because this process group"
  57. " spans across nodes.")
  58. return
  59. rank = dist.get_rank(group=self.group)
  60. world_size = dist.get_world_size(group=self.group)
  61. if world_size == 1:
  62. # No need to initialize custom allreduce for single GPU case.
  63. return
  64. if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
  65. if rank == 0:
  66. logger.warning(
  67. "Custom allreduce is disabled due to an unsupported world"
  68. f" size: {world_size}. Supported world sizes:"
  69. f"{str(CustomAllreduce._SUPPORTED_WORLD_SIZES)}. To "
  70. "silence this warning, specify disable_custom_all_reduce="
  71. "True explicitly.")
  72. return
  73. if isinstance(device, int):
  74. device = torch.device(f"cuda:{device}")
  75. elif isinstance(device, str):
  76. device = torch.device(device)
  77. # now `device` is a `torch.device` object
  78. assert isinstance(device, torch.device)
  79. self.device = device
  80. cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  81. if cuda_visible_devices:
  82. device_ids = list(map(int, cuda_visible_devices.split(",")))
  83. else:
  84. device_ids = list(range(cuda_device_count_stateless()))
  85. physical_device_id = device_ids[device.index]
  86. tensor = torch.tensor([physical_device_id],
  87. dtype=torch.int,
  88. device="cpu")
  89. gather_list = [
  90. torch.tensor([0], dtype=torch.int, device="cpu")
  91. for _ in range(world_size)
  92. ]
  93. dist.all_gather(gather_list, tensor, group=self.group)
  94. physical_device_ids = [t.item() for t in gather_list]
  95. # test nvlink first, this will filter out most of the cases
  96. # where custom allreduce is not supported
  97. # this checks hardware and driver support for NVLink
  98. assert current_platform.is_cuda()
  99. from aphrodite.platforms.cuda import CudaPlatform
  100. cuda_platform: CudaPlatform = current_platform
  101. full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
  102. if world_size > 2 and not full_nvlink:
  103. if rank == 0:
  104. logger.warning(
  105. "Custom allreduce is disabled because it's not supported "
  106. "on more than two PCIe-only GPUs. To silence this "
  107. "warning, specify disable_custom_all_reduce=True "
  108. "explicitly.")
  109. return
  110. # test P2P capability, this checks software/cudaruntime support
  111. # this is expensive to compute at the first time
  112. # then we cache the result
  113. if not _can_p2p(rank, world_size):
  114. if rank == 0:
  115. logger.warning(
  116. "Custom allreduce is disabled because your platform lacks "
  117. "GPU P2P capability or P2P test failed. To silence this "
  118. "warning, specify disable_custom_all_reduce=True "
  119. "explicitly.")
  120. return
  121. self.disabled = False
  122. # buffers memory are owned by this Python class and passed to C++
  123. # meta data composes of two parts: meta data for synchronization
  124. # (256 bytes) and a temporary buffer for storing intermediate
  125. # allreduce results.
  126. self.meta = torch.zeros(ops.meta_size() + max_size,
  127. dtype=torch.uint8,
  128. device=self.device)
  129. # This is a pre-registered IPC buffer. In eager mode, input tensors
  130. # are first copied into this buffer before allreduce is performed
  131. self.buffer = torch.empty(max_size,
  132. dtype=torch.uint8,
  133. device=self.device)
  134. # This is a buffer for storing the tuples of pointers pointing to
  135. # IPC buffers from all ranks. Each registered tuple has size of
  136. # 8*world_size bytes where world_size is at most 8. Allocating 8MB
  137. # is enough for 131072 such tuples. The largest model I've seen only
  138. # needs less than 10000 of registered tuples.
  139. self.rank_data = torch.empty(8 * 1024 * 1024,
  140. dtype=torch.uint8,
  141. device=self.device)
  142. self.max_size = max_size
  143. self.rank = rank
  144. self.world_size = world_size
  145. handles, offsets = self._get_ipc_meta(self.meta)
  146. self.full_nvlink = full_nvlink
  147. self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
  148. offsets, rank, self.full_nvlink)
  149. self.register_buffer(self.buffer)
  150. @contextmanager
  151. def capture(self):
  152. """
  153. The main responsibility of this context manager is the
  154. `register_graph_buffers` call at the end of the context.
  155. It records all the buffer addresses used in the CUDA graph.
  156. """
  157. try:
  158. self._IS_CAPTURING = True
  159. yield
  160. finally:
  161. self._IS_CAPTURING = False
  162. if not self.disabled:
  163. self.register_graph_buffers()
  164. def _get_ipc_meta(self, inp: torch.Tensor):
  165. data = inp.untyped_storage()._share_cuda_()
  166. shard_data = (
  167. data[1], # ipc handle to base ptr
  168. data[3], # offset of base ptr
  169. )
  170. return self._gather_ipc_meta(shard_data)
  171. def _gather_ipc_meta(self, shard_data):
  172. # Note: don't use `[[None]] * self.world_size` here
  173. # because it will create a list of the same reference
  174. all_data: List[Optional[Any]] = [[None]
  175. for i in range(self.world_size)]
  176. all_data[self.rank][0] = shard_data
  177. ranks = dist.get_process_group_ranks(group=self.group)
  178. ranks.sort()
  179. for i, rank in enumerate(ranks):
  180. dist.broadcast_object_list(all_data[i],
  181. src=rank,
  182. group=self.group,
  183. device="cpu")
  184. # we cannot directly use `dist.all_gather_object` here
  185. # because it is incompatible with `gloo` backend under inference mode.
  186. # see https://github.com/pytorch/pytorch/issues/126032 for details.
  187. handles = []
  188. offsets = []
  189. for i in range(len(all_data)):
  190. handles.append(all_data[i][0][0]) # type: ignore
  191. offsets.append(all_data[i][0][1]) # type: ignore
  192. return handles, offsets
  193. def register_buffer(self, inp: torch.Tensor):
  194. handles, offsets = self._get_ipc_meta(inp)
  195. ops.register_buffer(self._ptr, inp, handles, offsets)
  196. def register_graph_buffers(self):
  197. handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
  198. handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
  199. if self.rank == 0:
  200. logger.info(f"Registering {len(offset)} cuda graph addresses")
  201. ops.register_graph_buffers(self._ptr, handles, offsets)
  202. def should_custom_ar(self, inp: torch.Tensor):
  203. return ops.should_custom_ar(inp, self.max_size, self.world_size,
  204. self.full_nvlink)
  205. # all reduce, assuming inp tensor is IPC registered with register_buffer,
  206. # or, in the context of cuda graphs, register_graph_buffers
  207. def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
  208. if out is None:
  209. out = torch.empty_like(inp)
  210. ops.all_reduce_reg(self._ptr, inp, out)
  211. return out
  212. # all reduce, assuming inp tensor is NOT IPC registered
  213. def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
  214. if out is None:
  215. out = torch.empty_like(inp)
  216. ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
  217. return out
  218. def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
  219. # when custom allreduce is disabled, this will be None
  220. if self.disabled:
  221. return None
  222. if self._IS_CAPTURING:
  223. if torch.cuda.is_current_stream_capturing():
  224. if self.should_custom_ar(input):
  225. return self.all_reduce_reg(input)
  226. else:
  227. if self.should_custom_ar(input):
  228. # if warm up, mimic the allocation pattern
  229. # since custom allreduce is out-of-place
  230. return torch.empty_like(input)
  231. else:
  232. # note: outside of cuda graph context,
  233. # custom allreduce incurs a cost of cudaMemcpy, which should
  234. # be small(<=1% of overall latency) compared to the performance
  235. # gains of using custom kernels
  236. if self.should_custom_ar(input):
  237. return self.all_reduce_unreg(input)
  238. return None
  239. def close(self):
  240. if not self.disabled and self._ptr:
  241. ops.dispose(self._ptr)
  242. self._ptr = 0
  243. def __del__(self):
  244. self.close()