multiproc_gpu_executor.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import asyncio
  2. import os
  3. from functools import partial
  4. from typing import Any, List, Optional
  5. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  6. from aphrodite.common.utils import (get_aphrodite_instance_id,
  7. get_distributed_init_method, get_ip,
  8. get_open_port, make_async)
  9. from aphrodite.executor.distributed_gpu_executor import ( # yapf: disable
  10. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  11. from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
  12. ResultHandler,
  13. WorkerMonitor)
  14. class MultiprocessingGPUExecutor(DistributedGPUExecutor):
  15. """Python multiprocessing-based multi-GPU executor"""
  16. def _init_executor(self) -> None:
  17. # Create the parallel GPU workers.
  18. world_size = self.parallel_config.tensor_parallel_size
  19. # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
  20. if "CUDA_VISIBLE_DEVICES" not in os.environ:
  21. os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
  22. map(str, range(world_size))))
  23. # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
  24. os.environ["APHRODITE_INSTANCE_ID"] = get_aphrodite_instance_id()
  25. from torch.cuda import device_count
  26. assert world_size <= device_count(), (
  27. "please set tensor_parallel_size to less than max local gpu count")
  28. distributed_init_method = get_distributed_init_method(
  29. get_ip(), get_open_port())
  30. if world_size == 1:
  31. self.workers = []
  32. else:
  33. result_handler = ResultHandler()
  34. self.workers = [
  35. ProcessWorkerWrapper(
  36. result_handler,
  37. partial(
  38. self._create_worker,
  39. rank=rank,
  40. local_rank=rank,
  41. distributed_init_method=distributed_init_method,
  42. )) for rank in range(1, world_size)
  43. ]
  44. self.worker_monitor = WorkerMonitor(self.workers, result_handler)
  45. result_handler.start()
  46. self.worker_monitor.start()
  47. self.driver_worker = self._create_worker(
  48. distributed_init_method=distributed_init_method)
  49. self._run_workers("init_device")
  50. self._run_workers("load_model",
  51. max_concurrent_workers=self.parallel_config.
  52. max_parallel_loading_workers)
  53. def shutdown(self):
  54. if (worker_monitor := getattr(self, "worker_monitor",
  55. None)) is not None:
  56. worker_monitor.close()
  57. def _driver_execute_model(
  58. self,
  59. execute_model_req: Optional[ExecuteModelRequest] = None
  60. ) -> List[SamplerOutput]:
  61. """Run execute_model in the driver worker.
  62. Passing None will cause the driver to stop the model execution
  63. loop running in each of the remote workers.
  64. """
  65. return self.driver_worker.execute_model(
  66. execute_model_req=execute_model_req)
  67. def _run_workers(
  68. self,
  69. method: str,
  70. *args,
  71. async_run_remote_workers_only: bool = False,
  72. max_concurrent_workers: Optional[int] = None,
  73. **kwargs,
  74. ) -> Any:
  75. """Runs the given method on all workers.
  76. Args:
  77. async_run_remote_workers_only: If True the method will be run only
  78. in the remote workers, not the driver worker. It will also be
  79. run asynchronously and return a list of futures rather than
  80. blocking on the results.
  81. """
  82. if max_concurrent_workers:
  83. raise NotImplementedError(
  84. "max_concurrent_workers is not supported yet.")
  85. # Start the workers first.
  86. worker_outputs = [
  87. worker.execute_method(method, *args, **kwargs)
  88. for worker in self.workers
  89. ]
  90. if async_run_remote_workers_only:
  91. # Just return futures
  92. return worker_outputs
  93. driver_worker_method = getattr(self.driver_worker, method)
  94. driver_worker_output = driver_worker_method(*args, **kwargs)
  95. # Get the results of the workers.
  96. return [driver_worker_output
  97. ] + [output.get() for output in worker_outputs]
  98. def check_health(self) -> None:
  99. """Raises an error if engine is unhealthy."""
  100. if not self.worker_monitor.is_alive():
  101. raise RuntimeError("Worker processes are not running")
  102. def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
  103. """Wait for futures returned from _run_workers() with
  104. async_run_remote_workers_only to complete."""
  105. for result in parallel_worker_tasks:
  106. result.get()
  107. class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
  108. DistributedGPUExecutorAsync):
  109. def __init__(self, *args, **kwargs):
  110. super().__init__(*args, **kwargs)
  111. self.driver_exec_model = make_async(self.driver_worker.execute_model)
  112. async def _driver_execute_model_async(
  113. self,
  114. execute_model_req: Optional[ExecuteModelRequest] = None
  115. ) -> List[SamplerOutput]:
  116. return await self.driver_exec_model(execute_model_req)
  117. async def _start_worker_execution_loop(self):
  118. coros = [
  119. worker.execute_method_async("start_worker_execution_loop")
  120. for worker in self.workers
  121. ]
  122. return await asyncio.gather(*coros)