custom_all_reduce_utils.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import ctypes
  2. import json
  3. import os
  4. import pickle
  5. import subprocess
  6. import sys
  7. from itertools import product
  8. from typing import Dict, Optional, Sequence
  9. import torch.distributed as dist
  10. import torch.multiprocessing as mp
  11. from loguru import logger
  12. import aphrodite.common.envs as envs
  13. from aphrodite.common.utils import (cuda_device_count_stateless,
  14. update_environment_variables)
  15. from aphrodite.distributed.device_communicators.cuda_wrapper import (
  16. CudaRTLibrary)
  17. def producer(batch_src: Sequence[int],
  18. producer_queue,
  19. consumer_queue,
  20. result_queue,
  21. cuda_visible_devices: Optional[str] = None):
  22. if cuda_visible_devices is not None:
  23. update_environment_variables(
  24. {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
  25. lib = CudaRTLibrary()
  26. for i in batch_src:
  27. lib.cudaSetDevice(i)
  28. pointer = lib.cudaMalloc(1024)
  29. lib.cudaMemset(pointer, 1, 1024)
  30. lib.cudaDeviceSynchronize()
  31. handle = lib.cudaIpcGetMemHandle(pointer)
  32. producer_queue.put(handle)
  33. open_success = consumer_queue.get()
  34. if open_success:
  35. # use two queues to simulate barrier
  36. producer_queue.put(0)
  37. consumer_queue.get()
  38. # check if the memory is modified
  39. host_data = (ctypes.c_char * 1024)()
  40. lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
  41. for i in range(1024):
  42. if ord(host_data[i]) != 2:
  43. open_success = False
  44. break
  45. result_queue.put(open_success)
  46. lib.cudaDeviceReset()
  47. def consumer(batch_tgt: Sequence[int],
  48. producer_queue,
  49. consumer_queue,
  50. result_queue,
  51. cuda_visible_devices: Optional[str] = None):
  52. if cuda_visible_devices is not None:
  53. update_environment_variables(
  54. {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
  55. lib = CudaRTLibrary()
  56. for j in batch_tgt:
  57. lib.cudaSetDevice(j)
  58. handle = producer_queue.get()
  59. open_success = False
  60. try:
  61. pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
  62. open_success = True
  63. except RuntimeError:
  64. # cannot error out here, because the producer process
  65. # is still waiting for the response.
  66. pass
  67. consumer_queue.put(open_success)
  68. if open_success:
  69. # modify the memory
  70. lib.cudaMemset(pointer, 2, 1024)
  71. lib.cudaDeviceSynchronize()
  72. # use two queues to simulate barrier
  73. producer_queue.get()
  74. consumer_queue.put(0)
  75. # check if the memory is modified
  76. host_data = (ctypes.c_char * 1024)()
  77. lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
  78. for i in range(1024):
  79. if ord(host_data[i]) != 2:
  80. open_success = False
  81. break
  82. result_queue.put(open_success)
  83. lib.cudaDeviceReset()
  84. def can_actually_p2p(
  85. batch_src: Sequence[int],
  86. batch_tgt: Sequence[int],
  87. ):
  88. """
  89. Usually, checking if P2P access is enabled can be done by
  90. `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
  91. the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
  92. returns `True` even if P2P access is not actually possible.
  93. See
  94. https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
  95. Therefore, we have to perform a real P2P access to check if it is actually
  96. possible.
  97. Note on p2p and cuda IPC:
  98. Usually, one process uses one GPU:
  99. GPU src --> cuda context src --> tensor src --> process src
  100. We need to combine p2p and cuda IPC, so that:
  101. GPU src --> cuda context src --> tensor src --> process src
  102. |shared|
  103. GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
  104. That is to say, process src creates a tensor in GPU src, passes IPC handle to
  105. process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
  106. tensor in process tgt will be reflected in the tensor in process src, because
  107. they are the same memory segment.
  108. It is important to note that process tgt accesses the tensor in GPU tgt, not
  109. GPU src. That's why we need p2p access.
  110. The most time-consuming part is the process creation. To avoid creating
  111. processes for every pair of GPUs, we use batched testing. We create two
  112. processes for testing all pairs of GPUs in batch. The trick is to reset
  113. the device after each test (which is not available in PyTorch).
  114. """ # noqa
  115. cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
  116. # pass the CUDA_VISIBLE_DEVICES to the child process
  117. # to make sure they see the same set of GPUs
  118. # make sure the processes are spawned
  119. smp = mp.get_context("spawn")
  120. producer_queue = smp.Queue()
  121. consumer_queue = smp.Queue()
  122. result_queue = smp.Queue()
  123. p_src = smp.Process(target=producer,
  124. args=(batch_src, producer_queue, consumer_queue,
  125. result_queue, cuda_visible_devices))
  126. p_tgt = smp.Process(target=consumer,
  127. args=(batch_tgt, producer_queue, consumer_queue,
  128. result_queue, cuda_visible_devices))
  129. p_src.start()
  130. p_tgt.start()
  131. p_src.join()
  132. p_tgt.join()
  133. result = []
  134. for src, tgt in zip(batch_src, batch_tgt):
  135. a = result_queue.get()
  136. b = result_queue.get()
  137. if a != b:
  138. logger.warning("Two processes do not agree on the P2P access"
  139. f" status on {src} -> {tgt}, treat as disabled.")
  140. result.append(False)
  141. else:
  142. result.append(a)
  143. return result
# Why do we need this cache?
# We are testing peer-to-peer (p2p) access between GPUs, across processes.
# Testing every ordered pair naively would need N * N * 2 processes, where N
# is the number of visible devices. Even the batched test
# (`can_actually_p2p`) launches a subprocess plus two worker processes, which
# is slow, so we do not want to repeat it on every call.
# To reduce the time, we use a cache file to store the p2p access status.
# The cache file is generated by the local master process if it does not exist;
# then all the processes can read the cache file to check the p2p access status.
# Note that the cache file name is suffixed by CUDA_VISIBLE_DEVICES, so that
# different CUDA_VISIBLE_DEVICES settings (e.g. used by different aphrodite
# engines) get different cache files. The device ids in the cache file are
# **local** device ids, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the aphrodite engine.
#
# In-process copy of that cache file: maps "src->tgt" keys to booleans.
# It stays None until `gpu_p2p_access_check` loads the file for the first time.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
  157. def gpu_p2p_access_check(src: int, tgt: int) -> bool:
  158. """Check if GPU src can access GPU tgt."""
  159. # if the cache variable is already calculated,
  160. # read from the cache instead of checking it again
  161. global _gpu_p2p_access_cache
  162. if _gpu_p2p_access_cache is not None:
  163. return _gpu_p2p_access_cache[f"{src}->{tgt}"]
  164. is_distributed = dist.is_initialized()
  165. num_dev = cuda_device_count_stateless()
  166. cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
  167. if cuda_visible_devices is None:
  168. cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
  169. path = os.path.join(
  170. envs.APHRODITE_CACHE_ROOT,
  171. f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
  172. os.makedirs(os.path.dirname(path), exist_ok=True)
  173. from aphrodite.distributed.parallel_state import get_world_group
  174. if ((not is_distributed or get_world_group().local_rank == 0)
  175. and (not os.path.exists(path))):
  176. # only the local master process (with local_rank == 0) can
  177. # enter this block to calculate the cache
  178. logger.info(f"generating GPU P2P access cache in {path}")
  179. cache = {}
  180. ids = list(range(num_dev))
  181. # batch of all pairs of GPUs
  182. batch_src, batch_tgt = zip(*list(product(ids, ids)))
  183. # NOTE: we use `subprocess` rather than `multiprocessing` here
  184. # because the caller might not have `if __name__ == "__main__":`,
  185. # in that case we cannot use spawn method in multiprocessing.
  186. # However, `can_actually_p2p` requires spawn method.
  187. # The fix is, we use `subprocess` to call the function,
  188. # where we have `if __name__ == "__main__":` in this file.
  189. input_bytes = pickle.dumps((batch_src, batch_tgt))
  190. returned = subprocess.run([sys.executable, __file__],
  191. input=input_bytes,
  192. capture_output=True)
  193. # check if the subprocess is successful
  194. try:
  195. returned.check_returncode()
  196. except Exception as e:
  197. # wrap raised exception to provide more information
  198. raise RuntimeError(
  199. f"Error happened when batch testing "
  200. f"peer-to-peer access from {batch_src} to {batch_tgt}") from e
  201. result = pickle.loads(returned.stdout)
  202. for _i, _j, r in zip(batch_src, batch_tgt, result):
  203. cache[f"{_i}->{_j}"] = r
  204. with open(path, "w") as f:
  205. json.dump(cache, f, indent=4)
  206. if is_distributed:
  207. get_world_group().barrier()
  208. logger.debug(f"reading GPU P2P access cache from {path}")
  209. with open(path, "r") as f:
  210. cache = json.load(f)
  211. _gpu_p2p_access_cache = cache
  212. return _gpu_p2p_access_cache[f"{src}->{tgt}"]
  213. __all__ = ["gpu_p2p_access_check"]
  214. if __name__ == "__main__":
  215. batch_src, batch_tgt = pickle.loads(sys.stdin.buffer.read())
  216. result = can_actually_p2p(batch_src, batch_tgt)
  217. sys.stdout.buffer.write(pickle.dumps(result))