custom_all_reduce_utils.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. import ctypes
  2. import json
  3. import os
  4. import pickle
  5. import subprocess
  6. import sys
  7. from itertools import product
  8. from typing import Dict, Optional, Sequence
  9. import torch.distributed as dist
  10. import torch.multiprocessing as mp
  11. from loguru import logger
  12. from aphrodite.common.utils import (cuda_device_count_stateless,
  13. update_environment_variables)
  14. from aphrodite.distributed.device_communicators.cuda_wrapper import (
  15. CudaRTLibrary)
  16. def producer(batch_src: Sequence[int],
  17. producer_queue,
  18. consumer_queue,
  19. result_queue,
  20. cuda_visible_devices: Optional[str] = None):
  21. if cuda_visible_devices is not None:
  22. update_environment_variables(
  23. {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
  24. lib = CudaRTLibrary()
  25. for i in batch_src:
  26. lib.cudaSetDevice(i)
  27. pointer = lib.cudaMalloc(1024)
  28. lib.cudaMemset(pointer, 1, 1024)
  29. lib.cudaDeviceSynchronize()
  30. handle = lib.cudaIpcGetMemHandle(pointer)
  31. producer_queue.put(handle)
  32. open_success = consumer_queue.get()
  33. if open_success:
  34. # use two queues to simulate barrier
  35. producer_queue.put(0)
  36. consumer_queue.get()
  37. # check if the memory is modified
  38. host_data = (ctypes.c_char * 1024)()
  39. lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
  40. for i in range(1024):
  41. if ord(host_data[i]) != 2:
  42. open_success = False
  43. break
  44. result_queue.put(open_success)
  45. lib.cudaDeviceReset()
  46. def consumer(batch_tgt: Sequence[int],
  47. producer_queue,
  48. consumer_queue,
  49. result_queue,
  50. cuda_visible_devices: Optional[str] = None):
  51. if cuda_visible_devices is not None:
  52. update_environment_variables(
  53. {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
  54. lib = CudaRTLibrary()
  55. for j in batch_tgt:
  56. lib.cudaSetDevice(j)
  57. handle = producer_queue.get()
  58. open_success = False
  59. try:
  60. pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
  61. open_success = True
  62. except RuntimeError:
  63. # cannot error out here, because the producer process
  64. # is still waiting for the response.
  65. pass
  66. consumer_queue.put(open_success)
  67. if open_success:
  68. # modify the memory
  69. lib.cudaMemset(pointer, 2, 1024)
  70. lib.cudaDeviceSynchronize()
  71. # use two queues to simulate barrier
  72. producer_queue.get()
  73. consumer_queue.put(0)
  74. # check if the memory is modified
  75. host_data = (ctypes.c_char * 1024)()
  76. lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
  77. for i in range(1024):
  78. if ord(host_data[i]) != 2:
  79. open_success = False
  80. break
  81. result_queue.put(open_success)
  82. lib.cudaDeviceReset()
  83. def can_actually_p2p(
  84. batch_src: Sequence[int],
  85. batch_tgt: Sequence[int],
  86. ):
  87. """
  88. Usually, checking if P2P access is enabled can be done by
  89. `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
  90. the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
  91. returns `True` even if P2P access is not actually possible.
  92. See
  93. https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
  94. Therefore, we have to perform a real P2P access to check if it is actually
  95. possible.
  96. Note on p2p and cuda IPC:
  97. Usually, one process uses one GPU:
  98. GPU src --> cuda context src --> tensor src --> process src
  99. We need to combine p2p and cuda IPC, so that:
  100. GPU src --> cuda context src --> tensor src --> process src
  101. |shared|
  102. GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
  103. That is to say, process src creates a tensor in GPU src, passes IPC handle to
  104. process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
  105. tensor in process tgt will be reflected in the tensor in process src, because
  106. they are the same memory segment.
  107. It is important to note that process tgt accesses the tensor in GPU tgt, not
  108. GPU src. That's why we need p2p access.
  109. The most time-consuming part is the process creation. To avoid creating
  110. processes for every pair of GPUs, we use batched testing. We create two
  111. processes for testing all pairs of GPUs in batch. The trick is to reset
  112. the device after each test (which is not available in PyTorch).
  113. """ # noqa
  114. cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
  115. # pass the CUDA_VISIBLE_DEVICES to the child process
  116. # to make sure they see the same set of GPUs
  117. # make sure the processes are spawned
  118. smp = mp.get_context("spawn")
  119. producer_queue = smp.Queue()
  120. consumer_queue = smp.Queue()
  121. result_queue = smp.Queue()
  122. p_src = smp.Process(target=producer,
  123. args=(batch_src, producer_queue, consumer_queue,
  124. result_queue, cuda_visible_devices))
  125. p_tgt = smp.Process(target=consumer,
  126. args=(batch_tgt, producer_queue, consumer_queue,
  127. result_queue, cuda_visible_devices))
  128. p_src.start()
  129. p_tgt.start()
  130. p_src.join()
  131. p_tgt.join()
  132. result = []
  133. for src, tgt in zip(batch_src, batch_tgt):
  134. a = result_queue.get()
  135. b = result_queue.get()
  136. if a != b:
  137. logger.warning("Two processes do not agree on the P2P access"
  138. f" status on {src} -> {tgt}, treat as disabled.")
  139. result.append(False)
  140. else:
  141. result.append(a)
  142. return result
  143. # why do we need this cache?
  144. # we are testing peer-to-peer (p2p) access between GPUs,across processes.
  145. # if we test it every time, it will be very slow, because we need to create
  146. # N * N * 2 processes, where N is the world size. This is very slow.
  147. # to reduce the time, we use a cache file to store the p2p access status.
  148. # the cache file is generated by the master process if it does not exist.
  149. # then all the processes can read the cache file to check the p2p access status.
  150. # Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
  151. # can have different cache files for different CUDA_VISIBLE_DEVICES settings,
  152. # e.g. used by different aphrodite engines. The device id in the cache file is
  153. # a **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
  154. # of visible devices in the aphrodite engine.
  155. _gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
  156. def gpu_p2p_access_check(src: int, tgt: int) -> bool:
  157. """Check if GPU src can access GPU tgt."""
  158. # if the cache variable is already calculated,
  159. # read from the cache instead of checking it again
  160. global _gpu_p2p_access_cache
  161. if _gpu_p2p_access_cache is not None:
  162. return _gpu_p2p_access_cache[f"{src}->{tgt}"]
  163. is_distributed = dist.is_initialized()
  164. num_dev = cuda_device_count_stateless()
  165. cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
  166. if cuda_visible_devices is None:
  167. cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
  168. APHRODITE_CONFIG_ROOT = os.getenv("APHRODITE_CONFIG_ROOT", "~/.config")
  169. path = os.path.expanduser(
  170. f"{APHRODITE_CONFIG_ROOT}/aphrodite/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
  171. )
  172. os.makedirs(os.path.dirname(path), exist_ok=True)
  173. from aphrodite.distributed.parallel_state import get_world_group
  174. if ((not is_distributed or get_world_group().local_rank == 0)
  175. and (not os.path.exists(path))):
  176. # only the local master process (with local_rank == 0) can
  177. # enter this block to calculate the cache
  178. logger.info(f"generating GPU P2P access cache in {path}")
  179. cache = {}
  180. ids = list(range(num_dev))
  181. # batch of all pairs of GPUs
  182. batch_src, batch_tgt = zip(*list(product(ids, ids)))
  183. # NOTE: we use `subprocess` rather than `multiprocessing` here
  184. # because the caller might not have `if __name__ == "__main__":`,
  185. # in that case we cannot use spawn method in multiprocessing.
  186. # However, `can_actually_p2p` requires spawn method.
  187. # The fix is, we use `subprocess` to call the function,
  188. # where we have `if __name__ == "__main__":` in this file.
  189. input_bytes = pickle.dumps((batch_src, batch_tgt))
  190. returned = subprocess.run([sys.executable, __file__],
  191. input=input_bytes,
  192. capture_output=True)
  193. # check if the subprocess is successful
  194. try:
  195. returned.check_returncode()
  196. except Exception as e:
  197. # wrap raised exception to provide more information
  198. raise RuntimeError(
  199. f"Error happened when batch testing "
  200. f"peer-to-peer access from {batch_src} to {batch_tgt}") from e
  201. result = pickle.loads(returned.stdout)
  202. for _i, _j, r in zip(batch_src, batch_tgt, result):
  203. cache[f"{_i}->{_j}"] = r
  204. with open(path, "w") as f:
  205. json.dump(cache, f, indent=4)
  206. if is_distributed:
  207. get_world_group().barrier()
  208. logger.debug(f"reading GPU P2P access cache from {path}")
  209. with open(path, "r") as f:
  210. cache = json.load(f)
  211. _gpu_p2p_access_cache = cache
  212. return _gpu_p2p_access_cache[f"{src}->{tgt}"]
  213. __all__ = ["gpu_p2p_access_check"]
  214. if __name__ == "__main__":
  215. batch_src, batch_tgt = pickle.loads(sys.stdin.buffer.read())
  216. result = can_actually_p2p(batch_src, batch_tgt)
  217. sys.stdout.buffer.write(pickle.dumps(result))