multiproc_gpu_executor.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import asyncio
  2. import os
  3. from functools import partial
  4. from typing import Any, List, Optional
  5. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  6. from aphrodite.common.utils import (cuda_device_count_stateless,
  7. get_aphrodite_instance_id,
  8. get_distributed_init_method, get_ip,
  9. get_open_port, make_async)
  10. from aphrodite.executor.distributed_gpu_executor import ( # yapf: disable
  11. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  12. from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
  13. ResultHandler,
  14. WorkerMonitor)
  15. class MultiprocessingGPUExecutor(DistributedGPUExecutor):
  16. """Python multiprocessing-based multi-GPU executor"""
  17. def _init_executor(self) -> None:
  18. # Create the parallel GPU workers.
  19. world_size = self.parallel_config.tensor_parallel_size
  20. # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
  21. if "CUDA_VISIBLE_DEVICES" not in os.environ:
  22. os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
  23. map(str, range(world_size))))
  24. # Ensure that APHRODITE_INSTANCE_ID is set, to be inherited by workers
  25. os.environ["APHRODITE_INSTANCE_ID"] = get_aphrodite_instance_id()
  26. # Disable torch async compiling which won't work with daemonic processes
  27. os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
  28. assert world_size <= cuda_device_count_stateless(), (
  29. "please set tensor_parallel_size to less than max local gpu count")
  30. distributed_init_method = get_distributed_init_method(
  31. get_ip(), get_open_port())
  32. if world_size == 1:
  33. self.workers = []
  34. self.worker_monitor = None
  35. else:
  36. result_handler = ResultHandler()
  37. self.workers = [
  38. ProcessWorkerWrapper(
  39. result_handler,
  40. partial(
  41. self._create_worker,
  42. rank=rank,
  43. local_rank=rank,
  44. distributed_init_method=distributed_init_method,
  45. )) for rank in range(1, world_size)
  46. ]
  47. self.worker_monitor = WorkerMonitor(self.workers, result_handler)
  48. result_handler.start()
  49. self.worker_monitor.start()
  50. self.driver_worker = self._create_worker(
  51. distributed_init_method=distributed_init_method)
  52. self._run_workers("init_device")
  53. self._run_workers("load_model",
  54. max_concurrent_workers=self.parallel_config.
  55. max_parallel_loading_workers)
  56. def shutdown(self):
  57. if (worker_monitor := getattr(self, "worker_monitor",
  58. None)) is not None:
  59. worker_monitor.close()
  60. def _driver_execute_model(
  61. self,
  62. execute_model_req: Optional[ExecuteModelRequest] = None
  63. ) -> List[SamplerOutput]:
  64. """Run execute_model in the driver worker.
  65. Passing None will cause the driver to stop the model execution
  66. loop running in each of the remote workers.
  67. """
  68. return self.driver_worker.execute_model(
  69. execute_model_req=execute_model_req)
  70. def _run_workers(
  71. self,
  72. method: str,
  73. *args,
  74. async_run_remote_workers_only: bool = False,
  75. max_concurrent_workers: Optional[int] = None,
  76. **kwargs,
  77. ) -> Any:
  78. """Runs the given method on all workers.
  79. Args:
  80. async_run_remote_workers_only: If True the method will be run only
  81. in the remote workers, not the driver worker. It will also be
  82. run asynchronously and return a list of futures rather than
  83. blocking on the results.
  84. """
  85. if max_concurrent_workers:
  86. raise NotImplementedError(
  87. "max_concurrent_workers is not supported yet.")
  88. # Start the workers first.
  89. worker_outputs = [
  90. worker.execute_method(method, *args, **kwargs)
  91. for worker in self.workers
  92. ]
  93. if async_run_remote_workers_only:
  94. # Just return futures
  95. return worker_outputs
  96. driver_worker_method = getattr(self.driver_worker, method)
  97. driver_worker_output = driver_worker_method(*args, **kwargs)
  98. # Get the results of the workers.
  99. return [driver_worker_output
  100. ] + [output.get() for output in worker_outputs]
  101. def check_health(self) -> None:
  102. """Raises an error if engine is unhealthy."""
  103. if self.worker_monitor is not None and not self.worker_monitor.is_alive(
  104. ):
  105. raise RuntimeError("Worker processes are not running")
  106. def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
  107. """Wait for futures returned from _run_workers() with
  108. async_run_remote_workers_only to complete."""
  109. for result in parallel_worker_tasks:
  110. result.get()
  111. class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
  112. DistributedGPUExecutorAsync):
  113. def __init__(self, *args, **kwargs):
  114. super().__init__(*args, **kwargs)
  115. self.driver_exec_model = make_async(self.driver_worker.execute_model)
  116. async def _driver_execute_model_async(
  117. self,
  118. execute_model_req: Optional[ExecuteModelRequest] = None
  119. ) -> List[SamplerOutput]:
  120. return await self.driver_exec_model(execute_model_req)
  121. async def _start_worker_execution_loop(self):
  122. coros = [
  123. worker.execute_method_async("start_worker_execution_loop")
  124. for worker in self.workers
  125. ]
  126. return await asyncio.gather(*coros)