custom_all_reduce.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from contextlib import contextmanager
  2. from typing import Optional
  3. from loguru import logger
  4. import torch
  5. import torch.distributed as dist
  6. from aphrodite.distributed import gpu_p2p_access_check
  7. try:
  8. from aphrodite._C import custom_ar
  9. import pynvml
  10. except ImportError:
  11. # For AMD GPUs
  12. custom_ar = None
  13. pynvml = None
  14. _CA_HANDLE = None
  15. _IS_CAPTURING = False
  16. _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
  17. def init_custom_ar() -> None:
  18. from aphrodite.distributed import (
  19. get_tensor_model_parallel_rank,
  20. get_tensor_model_parallel_world_size)
  21. global _CA_HANDLE
  22. if _CA_HANDLE is not None:
  23. return
  24. rank = get_tensor_model_parallel_rank()
  25. world_size = get_tensor_model_parallel_world_size()
  26. if world_size == 1:
  27. return
  28. if world_size not in _SUPPORTED_WORLD_SIZES:
  29. logger.warning(
  30. "Custom allreduce is disabled due to an unsupported world size: "
  31. "%d. Supported world sizes: %s. To silence this warning, specify"
  32. " disable_custom_all_reduce=True explicitly.", world_size,
  33. str(_SUPPORTED_WORLD_SIZES))
  34. return
  35. num_dev = torch.cuda.device_count()
  36. # note: num dev can be larger than world_size if we're only using
  37. # first few GPUs
  38. if num_dev < world_size:
  39. logger.warning(
  40. "Cannot test GPU P2P because not all GPUs are visible to the "
  41. "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
  42. " is set.")
  43. return False
  44. # test nvlink first, this will filter out most of the cases
  45. # where custom allreduce is not supported
  46. full_nvlink = _is_full_nvlink(rank, world_size)
  47. if world_size > 2 and not full_nvlink:
  48. logger.warning(
  49. "Custom allreduce is disabled because it's not supported on more"
  50. " than two PCIe-only GPUs. To silence this warning, specify"
  51. " disable_custom_all_reduce=True explicitly.")
  52. return
  53. # test P2P capability
  54. # this is expensive to compute at the first time
  55. # then we cache the result
  56. if not _can_p2p(rank, world_size):
  57. logger.warning(
  58. "Custom allreduce is disabled because your platform lacks GPU P2P"
  59. " capability or P2P test failed. To silence this warning, specify"
  60. " disable_custom_all_reduce=True explicitly.")
  61. return
  62. _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
  63. def begin_capture() -> None:
  64. global _IS_CAPTURING
  65. _IS_CAPTURING = True
  66. def end_capture() -> None:
  67. global _IS_CAPTURING
  68. _IS_CAPTURING = False
  69. def is_capturing() -> bool:
  70. return _IS_CAPTURING and _CA_HANDLE is not None
  71. def get_handle() -> Optional["CustomAllreduce"]:
  72. return _CA_HANDLE
  73. def is_initialized() -> bool:
  74. return _CA_HANDLE is not None
  75. @contextmanager
  76. def capture():
  77. try:
  78. begin_capture()
  79. yield
  80. finally:
  81. end_capture()
  82. handle = get_handle()
  83. if handle is not None:
  84. handle.register_graph_buffers()
  85. # pylint: disable=redefined-builtin
  86. def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
  87. ca_handle = get_handle()
  88. # when custom allreduce is disabled, this will be None
  89. if ca_handle is None:
  90. return
  91. if is_capturing():
  92. if torch.cuda.is_current_stream_capturing():
  93. if ca_handle.should_custom_ar(input):
  94. return ca_handle.all_reduce_reg(input)
  95. else:
  96. if ca_handle.should_custom_ar(input):
  97. # if warm up, mimic the allocation pattern
  98. # since custom allreduce is out-of-place
  99. return torch.empty_like(input)
  100. else:
  101. # NOTE: outside of cuda graph context,
  102. # custom allreduce incurs a cost of cudaMemcpy, which should
  103. # be small(<=1% of overall latency) compared to the performance
  104. # gains of using custom kernels
  105. if ca_handle.should_custom_ar(input):
  106. return ca_handle.all_reduce_unreg(input)
  107. @contextmanager
  108. def _nvml():
  109. try:
  110. pynvml.nvmlInit()
  111. yield
  112. finally:
  113. pynvml.nvmlShutdown()
  114. # query if the set of gpus are fully connected by nvlink (1 hop)
  115. @_nvml()
  116. def _is_full_nvlink(rank, world_size):
  117. handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
  118. for i in range(world_size):
  119. if i != rank:
  120. try:
  121. link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
  122. if not link_state:
  123. return False
  124. except pynvml.NVMLError as error:
  125. logger.info(
  126. f"NVLink detection failed with message \"{str(error)}\". "
  127. "This is normal if your machine has no NVLink equipped")
  128. return False
  129. return True
  130. def _can_p2p(rank: int, world_size: int) -> bool:
  131. for i in range(world_size):
  132. if i == rank:
  133. continue
  134. if not gpu_p2p_access_check(rank, i):
  135. return False
  136. return True
  137. class CustomAllreduce:
  138. # max_size: max supported allreduce size
  139. def __init__(self,
  140. rank,
  141. world_size,
  142. full_nvlink,
  143. max_size=8192 * 1024) -> None:
  144. # buffers memory are owned by this Python class and passed to C++
  145. # meta data composes of two parts: meta data for synchronization
  146. # (256 bytes) and a temporary buffer for storing intermediate
  147. # allreduce results.
  148. self.meta = torch.zeros(custom_ar.meta_size() + max_size,
  149. dtype=torch.uint8,
  150. device="cuda")
  151. # This is a pre-registered IPC buffer. In eager mode, input tensors
  152. # are first copied into this buffer before allreduce is performed
  153. self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
  154. # This is a buffer for storing the tuples of pointers pointing to
  155. # IPC buffers from all ranks. Each registered tuple has size of
  156. # 8*world_size bytes where world_size is at most 8. Allocating 8MB
  157. # is enough for 131072 such tuples. The largest model I've seen only
  158. # needs less than 10000 of registered tuples.
  159. self.rank_data = torch.empty(8 * 1024 * 1024,
  160. dtype=torch.uint8,
  161. device="cuda")
  162. self.max_size = max_size
  163. self.world_size = world_size
  164. handles, offsets = self._get_ipc_meta(self.meta)
  165. self.full_nvlink = full_nvlink
  166. self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
  167. handles, offsets, rank,
  168. self.full_nvlink)
  169. self.register_buffer(self.buffer)
  170. def _get_ipc_meta(self, inp: torch.Tensor):
  171. # pylint: disable=protected-access
  172. data = inp.untyped_storage()._share_cuda_()
  173. shard_data = (
  174. data[1], # ipc handle to base ptr
  175. data[3], # offset of base ptr
  176. )
  177. return self._gather_ipc_meta(shard_data)
  178. def _gather_ipc_meta(self, shard_data):
  179. all_data = [None] * self.world_size
  180. dist.all_gather_object(all_data, shard_data)
  181. handles = []
  182. offsets = []
  183. for i in range(len(all_data)):
  184. handles.append(all_data[i][0])
  185. offsets.append(all_data[i][1])
  186. return handles, offsets
  187. def register_buffer(self, inp: torch.Tensor):
  188. handles, offsets = self._get_ipc_meta(inp)
  189. custom_ar.register_buffer(self._ptr, inp, handles, offsets)
  190. def register_graph_buffers(self):
  191. handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
  192. handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
  193. logger.info("Registering %d cuda graph addresses", len(offset))
  194. custom_ar.register_graph_buffers(self._ptr, handles, offsets)
  195. def should_custom_ar(self, inp: torch.Tensor):
  196. return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
  197. self.full_nvlink)
  198. # all reduce, assuming inp tensor is IPC registered with register_buffer,
  199. # or, in the context of cuda graphs, register_graph_buffers
  200. def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
  201. if out is None:
  202. out = torch.empty_like(inp)
  203. custom_ar.all_reduce_reg(self._ptr, inp, out)
  204. return out
  205. # all reduce, assuming inp tensor is NOT IPC registered
  206. def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
  207. if out is None:
  208. out = torch.empty_like(inp)
  209. custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
  210. return out
  211. def close(self):
  212. if self._ptr:
  213. custom_ar.dispose(self._ptr)
  214. self._ptr = 0
  215. def __del__(self):
  216. self.close()