multiproc_gpu_executor.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import asyncio
  2. import os
  3. import signal
  4. import threading
  5. import weakref
  6. from functools import partial
  7. from typing import Any, List, Optional
  8. import torch
  9. from loguru import logger
  10. from aphrodite.common.sequence import ExecuteModelRequest
  11. from aphrodite.common.utils import (_run_task_with_lock,
  12. cuda_device_count_stateless,
  13. get_aphrodite_instance_id,
  14. get_distributed_init_method, get_open_port,
  15. make_async, update_environment_variables)
  16. from aphrodite.executor.distributed_gpu_executor import ( # yapf: disable
  17. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  18. from aphrodite.executor.gpu_executor import create_worker
  19. from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
  20. ResultHandler,
  21. WorkerMonitor)
  22. from aphrodite.modeling.layers.sampler import SamplerOutput
  23. from aphrodite.triton_utils import maybe_set_triton_cache_manager
  24. class MultiprocessingGPUExecutor(DistributedGPUExecutor):
  25. """Python multiprocessing-based multi-GPU executor"""
  26. uses_ray: bool = False
  27. def _init_executor(self) -> None:
  28. self._check_executor_parameters()
  29. # Create the parallel GPU workers.
  30. world_size = self.parallel_config.world_size
  31. tensor_parallel_size = self.parallel_config.tensor_parallel_size
  32. # Ensure that APHRODITE_INSTANCE_ID is set, to be inherited by workers
  33. os.environ["APHRODITE_INSTANCE_ID"] = get_aphrodite_instance_id()
  34. # Disable torch async compiling which won't work with daemonic processes
  35. os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
  36. # Configure thread parallelism if OMP_NUM_THREADS isn't set
  37. #
  38. # Helps to avoid CPU contention. The default of spawning a thread per
  39. # core combined with multiprocessing for each GPU can have a negative
  40. # impact on performance. The contention is amplified when running in a
  41. # container where CPU limits can cause throttling.
  42. default_omp_num_threads = 1
  43. if "OMP_NUM_THREADS" not in os.environ and (
  44. current_parallelism :=
  45. torch.get_num_threads()) > default_omp_num_threads:
  46. logger.warning(
  47. f"Reducing Torch parallelism from {current_parallelism} "
  48. f"threads to {default_omp_num_threads} to avoid "
  49. "unnecessary CPU contention. Set OMP_NUM_THREADS in the "
  50. "external environment to tune this value as needed.")
  51. os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
  52. torch.set_num_threads(default_omp_num_threads)
  53. if world_size > 1:
  54. maybe_set_triton_cache_manager()
  55. # Multiprocessing-based executor does not support multi-node setting.
  56. # Since it only works for single node, we can use the loopback address
  57. # 127.0.0.1 for communication.
  58. distributed_init_method = get_distributed_init_method(
  59. "127.0.0.1", get_open_port())
  60. self.workers: List[ProcessWorkerWrapper] = []
  61. # This is the list of workers that are rank 0 of each TP group EXCEPT
  62. # global rank 0. These are the workers that will broadcast to the
  63. # rest of the workers.
  64. self.tp_driver_workers: List[ProcessWorkerWrapper] = []
  65. # This is the list of workers that are not drivers and not the first
  66. # worker in a TP group. These are the workers that will be
  67. # broadcasted to.
  68. self.non_driver_workers: List[ProcessWorkerWrapper] = []
  69. if world_size == 1:
  70. self.worker_monitor = None
  71. else:
  72. result_handler = ResultHandler()
  73. for rank in range(1, world_size):
  74. worker = ProcessWorkerWrapper(
  75. result_handler,
  76. partial(
  77. create_worker,
  78. **self._get_create_worker_kwargs(
  79. rank=rank,
  80. local_rank=rank,
  81. distributed_init_method=distributed_init_method,
  82. )))
  83. self.workers.append(worker)
  84. if rank % tensor_parallel_size == 0:
  85. self.tp_driver_workers.append(worker)
  86. else:
  87. self.non_driver_workers.append(worker)
  88. self.worker_monitor = WorkerMonitor(self.workers, result_handler)
  89. result_handler.start()
  90. self.worker_monitor.start()
  91. # Set up signal handlers to shutdown the executor cleanly
  92. # sometimes gc does not work well
  93. # Use weakref to avoid holding a reference to self
  94. ref = weakref.ref(self)
  95. def shutdown(signum, frame):
  96. if executor := ref():
  97. executor.shutdown()
  98. if threading.current_thread() is threading.main_thread():
  99. signal.signal(signal.SIGINT, shutdown)
  100. signal.signal(signal.SIGTERM, shutdown)
  101. self.driver_worker = self._create_worker(
  102. distributed_init_method=distributed_init_method)
  103. self._run_workers("init_device")
  104. self._run_workers("load_model",
  105. max_concurrent_workers=self.parallel_config.
  106. max_parallel_loading_workers)
  107. def _check_executor_parameters(self):
  108. world_size = self.parallel_config.world_size
  109. tensor_parallel_size = self.parallel_config.tensor_parallel_size
  110. # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
  111. if "CUDA_VISIBLE_DEVICES" not in os.environ:
  112. update_environment_variables({
  113. "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
  114. })
  115. cuda_device_count = cuda_device_count_stateless()
  116. # Use confusing message for more common TP-only case.
  117. assert tensor_parallel_size <= cuda_device_count, (
  118. f"please set tensor_parallel_size ({tensor_parallel_size}) "
  119. f"to less than max local gpu count ({cuda_device_count})")
  120. assert world_size <= cuda_device_count, (
  121. f"please ensure that world_size ({world_size}) "
  122. f"is less than than max local gpu count ({cuda_device_count})")
  123. def shutdown(self):
  124. if (worker_monitor := getattr(self, "worker_monitor",
  125. None)) is not None:
  126. worker_monitor.close()
  127. def _driver_execute_model(
  128. self, execute_model_req: Optional[ExecuteModelRequest]
  129. ) -> Optional[List[SamplerOutput]]:
  130. """Run execute_model in the driver worker.
  131. Passing None will cause the driver to stop the model execution
  132. loop running in each of the remote workers.
  133. """
  134. return self.driver_worker.execute_model(execute_model_req)
  135. def _run_workers(
  136. self,
  137. method: str,
  138. *args,
  139. async_run_tensor_parallel_workers_only: bool = False,
  140. max_concurrent_workers: Optional[int] = None,
  141. **kwargs,
  142. ) -> Any:
  143. """Runs the given method on all workers.
  144. Args:
  145. async_run_tensor_parallel_workers_only: If True the method will be
  146. run only in the remote TP workers, not the driver worker.
  147. It will also be run asynchronously and return a list of futures
  148. rather than blocking on the results.
  149. """
  150. if max_concurrent_workers:
  151. raise NotImplementedError(
  152. "max_concurrent_workers is not supported yet.")
  153. if async_run_tensor_parallel_workers_only:
  154. # Run only non-driver workers and just return futures.
  155. return [
  156. worker.execute_method(method, *args, **kwargs)
  157. for worker in self.non_driver_workers
  158. ]
  159. # Start all remote workers first.
  160. worker_outputs = [
  161. worker.execute_method(method, *args, **kwargs)
  162. for worker in self.workers
  163. ]
  164. driver_worker_method = getattr(self.driver_worker, method)
  165. driver_worker_output = driver_worker_method(*args, **kwargs)
  166. # Get the results of the workers.
  167. return [driver_worker_output
  168. ] + [output.get() for output in worker_outputs]
  169. def check_health(self) -> None:
  170. """Raises an error if engine is unhealthy."""
  171. if self.worker_monitor is not None and not self.worker_monitor.is_alive(
  172. ):
  173. raise RuntimeError("Worker processes are not running")
  174. def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
  175. """Wait for futures returned from _run_workers() with
  176. async_run_remote_workers_only to complete."""
  177. for result in parallel_worker_tasks:
  178. result.get()
  179. class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
  180. DistributedGPUExecutorAsync):
  181. def __init__(self, *args, **kwargs):
  182. super().__init__(*args, **kwargs)
  183. self.driver_exec_model = make_async(self.driver_worker.execute_model)
  184. self.pp_locks: Optional[List[asyncio.Lock]] = None
  185. async def _driver_execute_model_async(
  186. self,
  187. execute_model_req: Optional[ExecuteModelRequest] = None
  188. ) -> List[SamplerOutput]:
  189. if not self.tp_driver_workers:
  190. return await self.driver_exec_model(execute_model_req)
  191. if self.pp_locks is None:
  192. # This locks each pipeline parallel stage so multiple virtual
  193. # engines can't execute on the same stage at the same time
  194. # We create the locks here to avoid creating them in the constructor
  195. # which uses a different asyncio loop.
  196. self.pp_locks = [
  197. asyncio.Lock()
  198. for _ in range(self.parallel_config.pipeline_parallel_size)
  199. ]
  200. tasks = [
  201. asyncio.create_task(
  202. _run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
  203. execute_model_req))
  204. ]
  205. for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
  206. start=1):
  207. tasks.append(
  208. asyncio.create_task(
  209. _run_task_with_lock(driver_worker.execute_method_async,
  210. self.pp_locks[pp_rank],
  211. "execute_model", execute_model_req)))
  212. results = await asyncio.gather(*tasks)
  213. # Only the last PP stage has the final results.
  214. return results[-1]
  215. async def _start_worker_execution_loop(self):
  216. coros = [
  217. worker.execute_method_async("start_worker_execution_loop")
  218. for worker in self.non_driver_workers
  219. ]
  220. return await asyncio.gather(*coros)