custom_all_reduce_utils.py

import ctypes
import json
import os
import pickle
import subprocess
import sys
import tempfile
from itertools import product
from typing import Dict, Optional, Sequence

import torch.distributed as dist
import torch.multiprocessing as mp
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.utils import (cuda_device_count_stateless,
                                    update_environment_variables)
from aphrodite.distributed.device_communicators.cuda_wrapper import (
    CudaRTLibrary)


def producer(batch_src: Sequence[int],
             producer_queue,
             consumer_queue,
             result_queue,
             cuda_visible_devices: Optional[str] = None):
    if cuda_visible_devices is not None:
        update_environment_variables(
            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
    lib = CudaRTLibrary()
    for i in batch_src:
        lib.cudaSetDevice(i)
        pointer = lib.cudaMalloc(1024)
        lib.cudaMemset(pointer, 1, 1024)
        lib.cudaDeviceSynchronize()
        handle = lib.cudaIpcGetMemHandle(pointer)
        producer_queue.put(handle)
        open_success = consumer_queue.get()
        if open_success:
            # use two queues to simulate a barrier with the consumer
            producer_queue.put(0)
            consumer_queue.get()
            # check if the memory has been modified by the consumer
            host_data = (ctypes.c_char * 1024)()
            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
            for i in range(1024):
                if ord(host_data[i]) != 2:
                    open_success = False
                    break
        result_queue.put(open_success)
        lib.cudaDeviceReset()


def consumer(batch_tgt: Sequence[int],
             producer_queue,
             consumer_queue,
             result_queue,
             cuda_visible_devices: Optional[str] = None):
    if cuda_visible_devices is not None:
        update_environment_variables(
            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
    lib = CudaRTLibrary()
    for j in batch_tgt:
        lib.cudaSetDevice(j)
        handle = producer_queue.get()
        open_success = False
        try:
            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
            open_success = True
        except RuntimeError:
            # cannot error out here, because the producer process
            # is still waiting for the response
            pass
        consumer_queue.put(open_success)
        if open_success:
            # modify the memory
            lib.cudaMemset(pointer, 2, 1024)
            lib.cudaDeviceSynchronize()
            # use two queues to simulate a barrier with the producer
            producer_queue.get()
            consumer_queue.put(0)
            # check if the memory has been modified
            host_data = (ctypes.c_char * 1024)()
            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
            for i in range(1024):
                if ord(host_data[i]) != 2:
                    open_success = False
                    break
        result_queue.put(open_success)
        lib.cudaDeviceReset()
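

# The "two queues to simulate a barrier" trick used above, shown in isolation.
# This is an illustrative sketch (not executed anywhere in this module); the
# two columns are the producer and consumer processes:
#
#     # producer                        # consumer
#     producer_queue.put(0)             producer_queue.get()
#     consumer_queue.get()              consumer_queue.put(0)
#
# Neither process can pass this point before the other arrives, which
# guarantees that the consumer's cudaMemset (plus cudaDeviceSynchronize) has
# finished before the producer copies the buffer back to the host.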


def can_actually_p2p(
    batch_src: Sequence[int],
    batch_tgt: Sequence[int],
):
    """
    Usually, checking whether P2P access is enabled can be done with
    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes the
    driver is broken and `torch.cuda.can_device_access_peer(src, tgt)`
    returns `True` even though P2P access is not actually possible. See
    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
    Therefore, we have to perform a real P2P access to check whether it is
    actually possible.

    Note on P2P and CUDA IPC: usually, one process uses one GPU:

        GPU src --> cuda context src --> tensor src --> process src

    We need to combine P2P and CUDA IPC, so that:

        GPU src --> cuda context src --> tensor src --> process src
                                          |shared|
        GPU tgt --> cuda context tgt --> tensor tgt --> process tgt

    That is to say, process src creates a tensor on GPU src and passes the
    IPC handle to process tgt, which accesses the tensor from GPU tgt. Any
    operation on the tensor in process tgt is reflected in the tensor in
    process src, because they are the same memory segment.
    It is important to note that process tgt accesses the tensor from GPU
    tgt, not GPU src. That is why P2P access is needed.

    The most time-consuming part is process creation. To avoid creating one
    pair of processes for every pair of GPUs, all pairs are tested in batch:
    two processes walk through the pairs and reset the device after each
    test (something that is not available in PyTorch).
    """  # noqa
    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
    # pass CUDA_VISIBLE_DEVICES to the child processes
    # to make sure they see the same set of GPUs

    # make sure the processes are spawned
    smp = mp.get_context("spawn")
    producer_queue = smp.Queue()
    consumer_queue = smp.Queue()
    result_queue = smp.Queue()
    p_src = smp.Process(target=producer,
                        args=(batch_src, producer_queue, consumer_queue,
                              result_queue, cuda_visible_devices))
    p_tgt = smp.Process(target=consumer,
                        args=(batch_tgt, producer_queue, consumer_queue,
                              result_queue, cuda_visible_devices))
    p_src.start()
    p_tgt.start()
    p_src.join()
    p_tgt.join()
    result = []
    for src, tgt in zip(batch_src, batch_tgt):
        a = result_queue.get()
        b = result_queue.get()
        if a != b:
            logger.warning("Two processes do not agree on the P2P access"
                           f" status on {src} -> {tgt}, treating it as"
                           " disabled.")
            result.append(False)
        else:
            result.append(a)
    return result
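

# Illustrative usage of `can_actually_p2p` (a sketch, not executed as part of
# this module): test every ordered pair among two visible GPUs in one batch,
# mirroring how `gpu_p2p_access_check` below builds its batches. Because the
# function spawns processes, it should only be called from a
# `if __name__ == "__main__":` context:
#
#     if __name__ == "__main__":
#         ids = [0, 1]
#         batch_src, batch_tgt = zip(*product(ids, ids))
#         # e.g. [True, True, True, True] when P2P works in both directions
#         print(can_actually_p2p(batch_src, batch_tgt))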


# Why do we need this cache?
# We are testing peer-to-peer (p2p) access between GPUs, across processes.
# Testing it every time would be very slow, because we would need to create
# N * N * 2 processes, where N is the world size.
# To reduce the time, we use a cache file to store the p2p access status.
# The cache file is generated by the master process if it does not exist;
# all processes then read the cache file to check the p2p access status.
# Note that the cache file is suffixed by CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different aphrodite engines. The device id in the cache file
# is a **local** device id, i.e. from 0 to num_dev - 1, where num_dev is the
# number of visible devices in the aphrodite engine.
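#
# For illustration, with CUDA_VISIBLE_DEVICES="4,5" the cache file is named
# `gpu_p2p_access_cache_for_4,5.json` (under APHRODITE_CACHE_ROOT) and, when
# P2P works in both directions, would contain only local device ids:
#
#     {
#         "0->0": true,
#         "0->1": true,
#         "1->0": true,
#         "1->1": true
#     }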
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None


def gpu_p2p_access_check(src: int, tgt: int) -> bool:
    """Check if GPU `src` can access GPU `tgt` via P2P."""

    # if the cache variable has already been populated,
    # read from it instead of checking again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
        return _gpu_p2p_access_cache[f"{src}->{tgt}"]

    is_distributed = dist.is_initialized()

    num_dev = cuda_device_count_stateless()
    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
    if cuda_visible_devices is None:
        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))

    path = os.path.join(
        envs.APHRODITE_CACHE_ROOT,
        f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    from aphrodite.distributed.parallel_state import get_world_group
    if ((not is_distributed or get_world_group().local_rank == 0)
            and (not os.path.exists(path))):
        # only the local master process (with local_rank == 0) can
        # enter this block to calculate the cache
        logger.info(f"generating GPU P2P access cache in {path}")
        cache = {}
        ids = list(range(num_dev))
        # batch of all pairs of GPUs
        batch_src, batch_tgt = zip(*list(product(ids, ids)))
        # NOTE: we use `subprocess` rather than `multiprocessing` here
        # because the caller might not have `if __name__ == "__main__":`,
        # in which case the spawn method of multiprocessing cannot be used.
        # However, `can_actually_p2p` requires the spawn method.
        # The fix is to call the function via `subprocess`, since this file
        # does have an `if __name__ == "__main__":` block.
        # A temporary file stores the result; the output of the subprocess
        # is not used directly, because it might contain logging output.
        with tempfile.NamedTemporaryFile() as output_file:
            input_bytes = pickle.dumps(
                (batch_src, batch_tgt, output_file.name))
            returned = subprocess.run([sys.executable, __file__],
                                      input=input_bytes,
                                      capture_output=True)
            # check whether the subprocess succeeded
            try:
                returned.check_returncode()
            except Exception as e:
                # wrap the raised exception to provide more information
                raise RuntimeError(
                    f"Error happened when batch testing "
                    f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
                    f"{returned.stderr.decode()}") from e
            with open(output_file.name, "rb") as f:
                result = pickle.load(f)
            for _i, _j, r in zip(batch_src, batch_tgt, result):
                cache[f"{_i}->{_j}"] = r
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
        get_world_group().barrier()
    logger.debug(f"reading GPU P2P access cache from {path}")
    with open(path, "r") as f:
        cache = json.load(f)
    _gpu_p2p_access_cache = cache
    return _gpu_p2p_access_cache[f"{src}->{tgt}"]
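

# Illustrative caller (an assumption about usage, not part of this module):
# a custom all-reduce backend would typically enable itself only when every
# pair of ranks has working P2P access, e.g.
#
#     full_p2p = all(
#         gpu_p2p_access_check(i, j)
#         for i, j in product(range(world_size), repeat=2) if i != j)
#
# where `world_size` is a placeholder for the caller's tensor-parallel size.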


__all__ = ["gpu_p2p_access_check"]

if __name__ == "__main__":
    batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
    result = can_actually_p2p(batch_src, batch_tgt)
    with open(output_file, "wb") as f:
        f.write(pickle.dumps(result))