multiproc_gpu_executor.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import asyncio
  2. import os
  3. import signal
  4. import weakref
  5. from functools import partial
  6. from typing import Any, List, Optional
  7. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  8. from aphrodite.common.utils import (_run_task_with_lock,
  9. cuda_device_count_stateless,
  10. error_on_invalid_device_count_status,
  11. get_aphrodite_instance_id,
  12. get_distributed_init_method, get_open_port,
  13. make_async, update_environment_variables)
  14. from aphrodite.executor.distributed_gpu_executor import ( # yapf: disable
  15. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  16. from aphrodite.executor.gpu_executor import create_worker
  17. from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
  18. ResultHandler,
  19. WorkerMonitor)
  20. from aphrodite.triton_utils import maybe_set_triton_cache_manager
  21. class MultiprocessingGPUExecutor(DistributedGPUExecutor):
  22. """Python multiprocessing-based multi-GPU executor"""
  23. uses_ray: bool = False
  24. def _init_executor(self) -> None:
  25. # Create the parallel GPU workers.
  26. world_size = self.parallel_config.world_size
  27. tensor_parallel_size = self.parallel_config.tensor_parallel_size
  28. # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
  29. if "CUDA_VISIBLE_DEVICES" not in os.environ:
  30. update_environment_variables({
  31. "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
  32. })
  33. # Ensure that APHRODITE_INSTANCE_ID is set, to be inherited by workers
  34. os.environ["APHRODITE_INSTANCE_ID"] = get_aphrodite_instance_id()
  35. # Disable torch async compiling which won't work with daemonic processes
  36. os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
  37. # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
  38. # contention amongst the shards
  39. if "OMP_NUM_THREADS" not in os.environ:
  40. os.environ["OMP_NUM_THREADS"] = "1"
  41. if world_size > 1:
  42. maybe_set_triton_cache_manager()
  43. cuda_device_count = cuda_device_count_stateless()
  44. # Use confusing message for more common TP-only case.
  45. assert tensor_parallel_size <= cuda_device_count, (
  46. f"please set tensor_parallel_size ({tensor_parallel_size}) "
  47. f"to less than max local gpu count ({cuda_device_count})")
  48. assert world_size <= cuda_device_count, (
  49. f"please ensure that world_size ({world_size}) "
  50. f"is less than than max local gpu count ({cuda_device_count})")
  51. error_on_invalid_device_count_status()
  52. # Multiprocessing-based executor does not support multi-node setting.
  53. # Since it only works for single node, we can use the loopback address
  54. # 127.0.0.1 for communication.
  55. distributed_init_method = get_distributed_init_method(
  56. "127.0.0.1", get_open_port())
  57. self.workers: List[ProcessWorkerWrapper] = []
  58. # This is the list of workers that are rank 0 of each TP group EXCEPT
  59. # global rank 0. These are the workers that will broadcast to the
  60. # rest of the workers.
  61. self.tp_driver_workers: List[ProcessWorkerWrapper] = []
  62. # This is the list of workers that are not drivers and not the first
  63. # worker in a TP group. These are the workers that will be
  64. # broadcasted to.
  65. self.non_driver_workers: List[ProcessWorkerWrapper] = []
  66. if world_size == 1:
  67. self.worker_monitor = None
  68. else:
  69. result_handler = ResultHandler()
  70. for rank in range(1, world_size):
  71. worker = ProcessWorkerWrapper(
  72. result_handler,
  73. partial(
  74. create_worker,
  75. **self._get_create_worker_kwargs(
  76. rank=rank,
  77. local_rank=rank,
  78. distributed_init_method=distributed_init_method,
  79. )))
  80. self.workers.append(worker)
  81. if rank % tensor_parallel_size == 0:
  82. self.tp_driver_workers.append(worker)
  83. else:
  84. self.non_driver_workers.append(worker)
  85. self.worker_monitor = WorkerMonitor(self.workers, result_handler)
  86. result_handler.start()
  87. self.worker_monitor.start()
  88. # Set up signal handlers to shutdown the executor cleanly
  89. # sometimes gc does not work well
  90. # Use weakref to avoid holding a reference to self
  91. ref = weakref.ref(self)
  92. def shutdown(signum, frame):
  93. if executor := ref():
  94. executor.shutdown()
  95. signal.signal(signal.SIGINT, shutdown)
  96. signal.signal(signal.SIGTERM, shutdown)
  97. self.driver_worker = self._create_worker(
  98. distributed_init_method=distributed_init_method)
  99. self._run_workers("init_device")
  100. self._run_workers("load_model",
  101. max_concurrent_workers=self.parallel_config.
  102. max_parallel_loading_workers)
  103. def shutdown(self):
  104. if (worker_monitor := getattr(self, "worker_monitor",
  105. None)) is not None:
  106. worker_monitor.close()
  107. def _driver_execute_model(
  108. self, execute_model_req: Optional[ExecuteModelRequest]
  109. ) -> Optional[List[SamplerOutput]]:
  110. """Run execute_model in the driver worker.
  111. Passing None will cause the driver to stop the model execution
  112. loop running in each of the remote workers.
  113. """
  114. return self.driver_worker.execute_model(execute_model_req)
  115. def _run_workers(
  116. self,
  117. method: str,
  118. *args,
  119. async_run_tensor_parallel_workers_only: bool = False,
  120. max_concurrent_workers: Optional[int] = None,
  121. **kwargs,
  122. ) -> Any:
  123. """Runs the given method on all workers.
  124. Args:
  125. async_run_tensor_parallel_workers_only: If True the method will be
  126. run only in the remote TP workers, not the driver worker.
  127. It will also be run asynchronously and return a list of futures
  128. rather than blocking on the results.
  129. """
  130. if max_concurrent_workers:
  131. raise NotImplementedError(
  132. "max_concurrent_workers is not supported yet.")
  133. if async_run_tensor_parallel_workers_only:
  134. # Run only non-driver workers and just return futures.
  135. return [
  136. worker.execute_method(method, *args, **kwargs)
  137. for worker in self.non_driver_workers
  138. ]
  139. # Start all remote workers first.
  140. worker_outputs = [
  141. worker.execute_method(method, *args, **kwargs)
  142. for worker in self.workers
  143. ]
  144. driver_worker_method = getattr(self.driver_worker, method)
  145. driver_worker_output = driver_worker_method(*args, **kwargs)
  146. # Get the results of the workers.
  147. return [driver_worker_output
  148. ] + [output.get() for output in worker_outputs]
  149. def check_health(self) -> None:
  150. """Raises an error if engine is unhealthy."""
  151. if self.worker_monitor is not None and not self.worker_monitor.is_alive(
  152. ):
  153. raise RuntimeError("Worker processes are not running")
  154. def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
  155. """Wait for futures returned from _run_workers() with
  156. async_run_remote_workers_only to complete."""
  157. for result in parallel_worker_tasks:
  158. result.get()
  159. class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
  160. DistributedGPUExecutorAsync):
  161. def __init__(self, *args, **kwargs):
  162. super().__init__(*args, **kwargs)
  163. self.driver_exec_model = make_async(self.driver_worker.execute_model)
  164. self.pp_locks: Optional[List[asyncio.Lock]] = None
  165. async def _driver_execute_model_async(
  166. self,
  167. execute_model_req: Optional[ExecuteModelRequest] = None
  168. ) -> List[SamplerOutput]:
  169. if not self.tp_driver_workers:
  170. return await self.driver_exec_model(execute_model_req)
  171. if self.pp_locks is None:
  172. # This locks each pipeline parallel stage so multiple virtual
  173. # engines can't execute on the same stage at the same time
  174. # We create the locks here to avoid creating them in the constructor
  175. # which uses a different asyncio loop.
  176. self.pp_locks = [
  177. asyncio.Lock()
  178. for _ in range(self.parallel_config.pipeline_parallel_size)
  179. ]
  180. tasks = [
  181. asyncio.create_task(
  182. _run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
  183. execute_model_req))
  184. ]
  185. for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
  186. start=1):
  187. tasks.append(
  188. asyncio.create_task(
  189. _run_task_with_lock(driver_worker.execute_method_async,
  190. self.pp_locks[pp_rank],
  191. "execute_model", execute_model_req)))
  192. results = await asyncio.gather(*tasks)
  193. # Only the last PP stage has the final results.
  194. return results[-1]
  195. async def _start_worker_execution_loop(self):
  196. coros = [
  197. worker.execute_method_async("start_worker_execution_loop")
  198. for worker in self.non_driver_workers
  199. ]
  200. return await asyncio.gather(*coros)