custom_all_reduce.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import os
  2. from contextlib import contextmanager
  3. from typing import Any, List, Optional, Union
  4. import torch
  5. import torch.distributed as dist
  6. from loguru import logger
  7. from torch.distributed import ProcessGroup
  8. from aphrodite import _custom_ops as ops
  9. from aphrodite.common.utils import cuda_device_count_stateless, is_full_nvlink
  10. from aphrodite.distributed.device_communicators.custom_all_reduce_utils import \
  11. gpu_p2p_access_check
  12. from aphrodite.distributed.parallel_state import in_the_same_node_as
  13. try:
  14. assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
  15. custom_ar = True
  16. except Exception:
  17. # For AMD GPUs and CPUs
  18. custom_ar = False
  19. def _can_p2p(rank: int, world_size: int) -> bool:
  20. for i in range(world_size):
  21. if i == rank:
  22. continue
  23. if not gpu_p2p_access_check(rank, i):
  24. return False
  25. return True
  26. class CustomAllreduce:
  27. _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
  28. # max_size: max supported allreduce size
  29. def __init__(self,
  30. group: ProcessGroup,
  31. device: Union[int, str, torch.device],
  32. max_size=8192 * 1024) -> None:
  33. """
  34. Args:
  35. group: the process group to work on. If None, it will use the
  36. default process group.
  37. device: the device to bind the CustomAllreduce to. If None,
  38. it will be bind to f"cuda:{local_rank}".
  39. It is the caller's responsibility to make sure each communicator
  40. is bind to a unique device, and all communicators in this group
  41. are in the same node.
  42. """
  43. self._IS_CAPTURING = False
  44. self.disabled = True
  45. if not custom_ar:
  46. # disable because of missing custom allreduce library
  47. # e.g. in a non-cuda environment
  48. return
  49. self.group = group
  50. assert dist.get_backend(group) != dist.Backend.NCCL, (
  51. "CustomAllreduce should be attached to a non-NCCL group.")
  52. if not all(in_the_same_node_as(group, source_rank=0)):
  53. # No need to initialize custom allreduce for multi-node case.
  54. logger.warning(
  55. "Custom allreduce is disabled because this process group"
  56. " spans across nodes.")
  57. return
  58. rank = dist.get_rank(group=self.group)
  59. world_size = dist.get_world_size(group=self.group)
  60. if world_size == 1:
  61. # No need to initialize custom allreduce for single GPU case.
  62. return
  63. if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
  64. logger.warning(
  65. "Custom allreduce is disabled due to an unsupported world"
  66. f" size: {world_size}. Supported world sizes:"
  67. f"{str(CustomAllreduce._SUPPORTED_WORLD_SIZES)}. To silence "
  68. "this warning, specify disable_custom_all_reduce=True "
  69. "explicitly.")
  70. return
  71. if isinstance(device, int):
  72. device = torch.device(f"cuda:{device}")
  73. elif isinstance(device, str):
  74. device = torch.device(device)
  75. # now `device` is a `torch.device` object
  76. assert isinstance(device, torch.device)
  77. self.device = device
  78. cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  79. if cuda_visible_devices:
  80. device_ids = list(map(int, cuda_visible_devices.split(",")))
  81. else:
  82. device_ids = list(range(cuda_device_count_stateless()))
  83. physical_device_id = device_ids[device.index]
  84. tensor = torch.tensor([physical_device_id],
  85. dtype=torch.int,
  86. device="cpu")
  87. gather_list = [
  88. torch.tensor([0], dtype=torch.int, device="cpu")
  89. for _ in range(world_size)
  90. ]
  91. dist.all_gather(gather_list, tensor, group=self.group)
  92. physical_device_ids = [t.item() for t in gather_list]
  93. # test nvlink first, this will filter out most of the cases
  94. # where custom allreduce is not supported
  95. # this checks hardware and driver support for NVLink
  96. full_nvlink = is_full_nvlink(physical_device_ids)
  97. if world_size > 2 and not full_nvlink:
  98. logger.warning(
  99. "Custom allreduce is disabled because it's not supported on"
  100. " more than two PCIe-only GPUs. To silence this warning, "
  101. "specify disable_custom_all_reduce=True explicitly.")
  102. return
  103. # test P2P capability, this checks software/cudaruntime support
  104. # this is expensive to compute at the first time
  105. # then we cache the result
  106. if not _can_p2p(rank, world_size):
  107. logger.warning(
  108. "Custom allreduce is disabled because your platform lacks "
  109. "GPU P2P capability or P2P test failed. To silence this "
  110. "warning, specify disable_custom_all_reduce=True explicitly.")
  111. return
  112. self.disabled = False
  113. # buffers memory are owned by this Python class and passed to C++
  114. # meta data composes of two parts: meta data for synchronization
  115. # (256 bytes) and a temporary buffer for storing intermediate
  116. # allreduce results.
  117. self.meta = torch.zeros(ops.meta_size() + max_size,
  118. dtype=torch.uint8,
  119. device=self.device)
  120. # This is a pre-registered IPC buffer. In eager mode, input tensors
  121. # are first copied into this buffer before allreduce is performed
  122. self.buffer = torch.empty(max_size,
  123. dtype=torch.uint8,
  124. device=self.device)
  125. # This is a buffer for storing the tuples of pointers pointing to
  126. # IPC buffers from all ranks. Each registered tuple has size of
  127. # 8*world_size bytes where world_size is at most 8. Allocating 8MB
  128. # is enough for 131072 such tuples. The largest model I've seen only
  129. # needs less than 10000 of registered tuples.
  130. self.rank_data = torch.empty(8 * 1024 * 1024,
  131. dtype=torch.uint8,
  132. device=self.device)
  133. self.max_size = max_size
  134. self.rank = rank
  135. self.world_size = world_size
  136. handles, offsets = self._get_ipc_meta(self.meta)
  137. self.full_nvlink = full_nvlink
  138. self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
  139. offsets, rank, self.full_nvlink)
  140. self.register_buffer(self.buffer)
  141. @contextmanager
  142. def capture(self):
  143. """
  144. The main responsibility of this context manager is the
  145. `register_graph_buffers` call at the end of the context.
  146. It records all the buffer addresses used in the CUDA graph.
  147. """
  148. try:
  149. self._IS_CAPTURING = True
  150. yield
  151. finally:
  152. self._IS_CAPTURING = False
  153. if not self.disabled:
  154. self.register_graph_buffers()
  155. def _get_ipc_meta(self, inp: torch.Tensor):
  156. data = inp.untyped_storage()._share_cuda_()
  157. shard_data = (
  158. data[1], # ipc handle to base ptr
  159. data[3], # offset of base ptr
  160. )
  161. return self._gather_ipc_meta(shard_data)
  162. def _gather_ipc_meta(self, shard_data):
  163. # Note: don't use `[[None]] * self.world_size` here
  164. # because it will create a list of the same reference
  165. all_data: List[Optional[Any]] = [[None]
  166. for i in range(self.world_size)]
  167. all_data[self.rank][0] = shard_data
  168. ranks = dist.get_process_group_ranks(group=self.group)
  169. ranks.sort()
  170. for i, rank in enumerate(ranks):
  171. dist.broadcast_object_list(all_data[i],
  172. src=rank,
  173. group=self.group,
  174. device="cpu")
  175. # we cannot directly use `dist.all_gather_object` here
  176. # because it is incompatible with `gloo` backend under inference mode.
  177. # see https://github.com/pytorch/pytorch/issues/126032 for details.
  178. handles = []
  179. offsets = []
  180. for i in range(len(all_data)):
  181. handles.append(all_data[i][0][0]) # type: ignore
  182. offsets.append(all_data[i][0][1]) # type: ignore
  183. return handles, offsets
  184. def register_buffer(self, inp: torch.Tensor):
  185. handles, offsets = self._get_ipc_meta(inp)
  186. ops.register_buffer(self._ptr, inp, handles, offsets)
  187. def register_graph_buffers(self):
  188. handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
  189. handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
  190. logger.info(f"Registering {len(offset)} cuda graph addresses")
  191. ops.register_graph_buffers(self._ptr, handles, offsets)
  192. def should_custom_ar(self, inp: torch.Tensor):
  193. return ops.should_custom_ar(inp, self.max_size, self.world_size,
  194. self.full_nvlink)
  195. # all reduce, assuming inp tensor is IPC registered with register_buffer,
  196. # or, in the context of cuda graphs, register_graph_buffers
  197. def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
  198. if out is None:
  199. out = torch.empty_like(inp)
  200. ops.all_reduce_reg(self._ptr, inp, out)
  201. return out
  202. # all reduce, assuming inp tensor is NOT IPC registered
  203. def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
  204. if out is None:
  205. out = torch.empty_like(inp)
  206. ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
  207. return out
  208. def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
  209. # when custom allreduce is disabled, this will be None
  210. if self.disabled:
  211. return None
  212. if self._IS_CAPTURING:
  213. if torch.cuda.is_current_stream_capturing():
  214. if self.should_custom_ar(input):
  215. return self.all_reduce_reg(input)
  216. else:
  217. if self.should_custom_ar(input):
  218. # if warm up, mimic the allocation pattern
  219. # since custom allreduce is out-of-place
  220. return torch.empty_like(input)
  221. else:
  222. # note: outside of cuda graph context,
  223. # custom allreduce incurs a cost of cudaMemcpy, which should
  224. # be small(<=1% of overall latency) compared to the performance
  225. # gains of using custom kernels
  226. if self.should_custom_ar(input):
  227. return self.all_reduce_unreg(input)
  228. return None
  229. def close(self):
  230. if not self.disabled and self._ptr:
  231. ops.dispose(self._ptr)
  232. self._ptr = 0
  233. def __del__(self):
  234. self.close()