custom_all_reduce.py 8.9 KB

from contextlib import contextmanager
from typing import Optional

from loguru import logger
import torch
import torch.distributed as dist

from aphrodite.distributed import gpu_p2p_access_check

try:
    from aphrodite._C import custom_ar
    import pynvml
except ImportError:
    # For AMD GPUs
    custom_ar = None
    pynvml = None

_CA_HANDLE = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]


def init_custom_ar() -> None:
    from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                       get_tensor_model_parallel_world_size)

    global _CA_HANDLE
    if _CA_HANDLE is not None:
        return
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        return
    if world_size not in _SUPPORTED_WORLD_SIZES:
        logger.warning(
            "Custom allreduce is disabled due to an unsupported world size: "
            "{}. Supported world sizes: {}. To silence this warning, specify"
            " disable_custom_all_reduce=True explicitly.", world_size,
            str(_SUPPORTED_WORLD_SIZES))
        return
    num_dev = torch.cuda.device_count()
    # NOTE: num_dev can be larger than world_size if only the first few GPUs
    # are being used.
    if num_dev < world_size:
        logger.warning(
            "Cannot test GPU P2P because not all GPUs are visible to the "
            "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
            " is set.")
        return
    # Test NVLink first; this filters out most of the cases where custom
    # allreduce is not supported.
    full_nvlink = _is_full_nvlink(rank, world_size)
    if world_size > 2 and not full_nvlink:
        logger.warning(
            "Custom allreduce is disabled because it's not supported on more"
            " than two PCIe-only GPUs. To silence this warning, specify"
            " disable_custom_all_reduce=True explicitly.")
        return
    # Test P2P capability. This is expensive to compute the first time, so
    # the result is cached.
    if not _can_p2p(rank, world_size):
        logger.warning(
            "Custom allreduce is disabled because your platform lacks GPU P2P"
            " capability or the P2P test failed. To silence this warning,"
            " specify disable_custom_all_reduce=True explicitly.")
        return
    _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
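

# Hedged usage sketch (a hypothetical helper, not part of the original file):
# init_custom_ar() returns silently when the topology is unsupported, so a
# caller can only tell whether the handle was created by checking
# is_initialized() afterwards. Assumes torch.distributed and the
# tensor-model-parallel group are already set up.
def _example_try_init_custom_ar() -> bool:
    """Return True if the custom allreduce handle was actually created."""
    init_custom_ar()
    return is_initialized()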


def begin_capture() -> None:
    global _IS_CAPTURING
    _IS_CAPTURING = True


def end_capture() -> None:
    global _IS_CAPTURING
    _IS_CAPTURING = False


def is_capturing() -> bool:
    return _IS_CAPTURING and _CA_HANDLE is not None


def get_handle() -> Optional["CustomAllreduce"]:
    return _CA_HANDLE


def is_initialized() -> bool:
    return _CA_HANDLE is not None


@contextmanager
def capture():
    try:
        begin_capture()
        yield
    finally:
        end_capture()
        handle = get_handle()
        if handle is not None:
            handle.register_graph_buffers()
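

# Hedged usage sketch (hypothetical call site, not part of the original file):
# CUDA graph capture is wrapped in capture() so that, once capture ends, the
# graph's allreduce buffers get IPC-registered via register_graph_buffers().
# The warmup call outside torch.cuda.graph mimics the out-of-place allocation
# pattern, as noted in custom_all_reduce() below; both calls may return None
# if the custom path declines the input.
def _example_capture_allreduce_graph(inp: torch.Tensor):
    graph = torch.cuda.CUDAGraph()
    with capture():
        custom_all_reduce(inp)  # warmup / allocation-pattern pass
        with torch.cuda.graph(graph):
            out = custom_all_reduce(inp)  # recorded registered-buffer kernel
    return graph, out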


# pylint: disable=redefined-builtin
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
    ca_handle = get_handle()
    # When custom allreduce is disabled, the handle is None.
    if ca_handle is None:
        return
    if is_capturing():
        if torch.cuda.is_current_stream_capturing():
            if ca_handle.should_custom_ar(input):
                return ca_handle.all_reduce_reg(input)
        else:
            if ca_handle.should_custom_ar(input):
                # During warmup, mimic the allocation pattern since custom
                # allreduce is out-of-place.
                return torch.empty_like(input)
    else:
        # NOTE: Outside of the CUDA graph context, custom allreduce incurs
        # the cost of a cudaMemcpy, which should be small (<=1% of overall
        # latency) compared to the performance gains of the custom kernels.
        if ca_handle.should_custom_ar(input):
            return ca_handle.all_reduce_unreg(input)
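

# Hedged usage sketch (assumed fallback pattern, not part of the original
# file): eager-mode callers treat a None return as "custom allreduce declined"
# and fall back to the regular torch.distributed collective (shown here on the
# default process group; a real caller would pass the tensor-parallel group).
def _example_all_reduce_with_fallback(t: torch.Tensor) -> torch.Tensor:
    out = custom_all_reduce(t)
    if out is not None:
        return out
    dist.all_reduce(t)  # in-place fallback when the custom path is unavailable
    return t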


@contextmanager
def _nvml():
    try:
        pynvml.nvmlInit()
        yield
    finally:
        pynvml.nvmlShutdown()


# Query whether the set of GPUs is fully connected by NVLink (1 hop).
@_nvml()
def _is_full_nvlink(rank, world_size):
    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
    for i in range(world_size):
        if i != rank:
            try:
                peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
            except pynvml.NVMLError as error:
                logger.info(
                    f"NVLink detection failed with message \"{str(error)}\". "
                    "This is normal if the machine has no NVLink.")
                return False
    return True


def _can_p2p(rank: int, world_size: int) -> bool:
    for i in range(world_size):
        if i == rank:
            continue
        if not gpu_p2p_access_check(rank, i):
            return False
    return True
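

# Illustrative helper (an assumption, not part of the original file): a direct
# torch-level peer-access probe, roughly the capability that
# gpu_p2p_access_check verifies (and caches) above.
def _example_torch_p2p_probe(rank: int, world_size: int) -> bool:
    return all(
        torch.cuda.can_device_access_peer(rank, i)
        for i in range(world_size) if i != rank)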


class CustomAllreduce:

    # max_size: max supported allreduce size
    def __init__(self,
                 rank,
                 world_size,
                 full_nvlink,
                 max_size=8192 * 1024) -> None:
        # Buffer memory is owned by this Python class and passed to C++.
        # The meta tensor consists of two parts: metadata for synchronization
        # (256 bytes) and a temporary buffer for storing intermediate
        # allreduce results.
        self.meta = torch.zeros(custom_ar.meta_size() + max_size,
                                dtype=torch.uint8,
                                device="cuda")
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed.
        self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has a size of
        # 8 * world_size bytes, where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples. The largest model I've seen needs
        # fewer than 10000 registered tuples.
        self.rank_data = torch.empty(8 * 1024 * 1024,
                                     dtype=torch.uint8,
                                     device="cuda")
        self.max_size = max_size
        self.world_size = world_size
        handles, offsets = self._get_ipc_meta(self.meta)
        self.full_nvlink = full_nvlink
        self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
                                             handles, offsets, rank,
                                             self.full_nvlink)
        self.register_buffer(self.buffer)

    def _get_ipc_meta(self, inp: torch.Tensor):
        # pylint: disable=protected-access
        data = inp.untyped_storage()._share_cuda_()
        shard_data = (
            data[1],  # IPC handle to the base pointer
            data[3],  # offset from the base pointer
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        all_data = [None] * self.world_size
        dist.all_gather_object(all_data, shard_data)
        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0])
            offsets.append(all_data[i][1])
        return handles, offsets

    def register_buffer(self, inp: torch.Tensor):
        handles, offsets = self._get_ipc_meta(inp)
        custom_ar.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
        handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
        logger.info("Registering {} cuda graph addresses", len(offset))
        custom_ar.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
                                          self.full_nvlink)

    # All-reduce, assuming the input tensor is already IPC-registered with
    # register_buffer, or, in the context of CUDA graphs, with
    # register_graph_buffers.
    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_reg(self._ptr, inp, out)
        return out

    # All-reduce, assuming the input tensor is NOT IPC-registered.
    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out

    def close(self):
        if self._ptr:
            custom_ar.dispose(self._ptr)
            self._ptr = 0

    def __del__(self):
        self.close()
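

# Worked check of the rank_data sizing comment in CustomAllreduce.__init__
# (illustrative only): with world_size == 8, each registered pointer tuple
# occupies 8 * 8 = 64 bytes, so the 8 MiB rank_data buffer holds
# 8 * 1024 * 1024 // 64 == 131072 tuples, matching the figure quoted above.
def _example_rank_data_capacity(world_size: int = 8,
                                rank_data_bytes: int = 8 * 1024 * 1024) -> int:
    return rank_data_bytes // (8 * world_size)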