multiproc_gpu_executor.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import asyncio
  2. import os
  3. from functools import partial
  4. from typing import Any, List, Optional
  5. from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
  6. from aphrodite.common.utils import (cuda_device_count_stateless,
  7. get_aphrodite_instance_id,
  8. get_distributed_init_method, get_open_port,
  9. make_async, update_environment_variables)
  10. from aphrodite.executor.distributed_gpu_executor import ( # yapf: disable
  11. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  12. from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
  13. ResultHandler,
  14. WorkerMonitor)
  class MultiprocessingGPUExecutor(DistributedGPUExecutor):
      """Multi-GPU executor built on Python multiprocessing.

      The driver worker is created directly through ``_create_worker``, while
      every rank from 1 up to ``tensor_parallel_size - 1`` is wrapped in a
      ``ProcessWorkerWrapper``.
      """

      def _init_executor(self) -> None:
          """Prepare the environment, create all workers and load the model.

          Environment variables are written before any worker is created so
          that the workers inherit them. Afterwards ``init_device`` and
          ``load_model`` are run on every worker through ``_run_workers``.
          """
          tp_size = self.parallel_config.tensor_parallel_size

          # The driver's CUDA_VISIBLE_DEVICES is inherited by the workers;
          # only supply a default when nothing has been configured yet.
          if "CUDA_VISIBLE_DEVICES" not in os.environ:
              default_devices = ",".join(str(idx) for idx in range(tp_size))
              update_environment_variables(
                  {"CUDA_VISIBLE_DEVICES": default_devices})

          # Workers inherit APHRODITE_INSTANCE_ID, so make sure it is set.
          os.environ["APHRODITE_INSTANCE_ID"] = get_aphrodite_instance_id()

          # Torch async compiling won't work with daemonic processes, so
          # turn it off.
          os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

          assert tp_size <= cuda_device_count_stateless(), (
              "please set tensor_parallel_size to less than max local gpu count")

          # This executor does not support multi-node setups. Because it only
          # works on a single node, the loopback address 127.0.0.1 is enough
          # for communication.
          init_method = get_distributed_init_method("127.0.0.1",
                                                    get_open_port())

          if world_size_is_multi := (tp_size != 1):
              handler = ResultHandler()
              remote_workers = []
              # Rank 0 is left for the driver worker created further down.
              for rank in range(1, tp_size):
                  make_worker = partial(
                      self._create_worker,
                      rank=rank,
                      local_rank=rank,
                      distributed_init_method=init_method,
                  )
                  remote_workers.append(
                      ProcessWorkerWrapper(handler, make_worker))
              self.workers = remote_workers
              self.worker_monitor = WorkerMonitor(remote_workers, handler)
              handler.start()
              self.worker_monitor.start()
          if not world_size_is_multi:
              # Exactly one rank: no wrapped workers and nothing to monitor.
              self.workers = []
              self.worker_monitor = None

          self.driver_worker = self._create_worker(
              distributed_init_method=init_method)

          self._run_workers("init_device")
          loading_limit = self.parallel_config.max_parallel_loading_workers
          self._run_workers("load_model",
                            max_concurrent_workers=loading_limit)

      def shutdown(self):
          """Close the worker monitor, if one exists.

          The attribute is looked up with ``getattr`` so the call is harmless
          on an instance where ``worker_monitor`` was never assigned.
          """
          monitor = getattr(self, "worker_monitor", None)
          if monitor is None:
              return
          monitor.close()

      def _driver_execute_model(
          self, execute_model_req: Optional[ExecuteModelRequest]
      ) -> Optional[List[SamplerOutput]]:
          """Run ``execute_model`` in the driver worker.

          Passing ``None`` will cause the driver to stop the model execution
          loop running in each of the remote workers.
          """
          driver_output = self.driver_worker.execute_model(execute_model_req)
          return driver_output

      def _run_workers(
          self,
          method: str,
          *args,
          async_run_tensor_parallel_workers_only: bool = False,
          max_concurrent_workers: Optional[int] = None,
          **kwargs,
      ) -> Any:
          """Run the named method on all workers.

          Args:
              method: Name of the worker method to invoke.
              *args: Positional arguments forwarded to the method.
              async_run_tensor_parallel_workers_only: If True the method will
                  be run only in the remote TP workers, not the driver
                  worker. It is then run asynchronously and a list of futures
                  is returned rather than blocking on the results.
              max_concurrent_workers: Not supported yet.
              **kwargs: Keyword arguments forwarded to the method.

          Returns:
              The list of futures when
              ``async_run_tensor_parallel_workers_only`` is True; otherwise a
              list holding the driver worker's output followed by the result
              of each entry in ``self.workers``.

          Raises:
              NotImplementedError: If ``max_concurrent_workers`` is truthy.
          """
          if max_concurrent_workers:
              raise NotImplementedError(
                  "max_concurrent_workers is not supported yet.")

          # Kick off the wrapped workers before doing any driver-side work.
          pending = []
          for remote in self.workers:
              pending.append(remote.execute_method(method, *args, **kwargs))

          if async_run_tensor_parallel_workers_only:
              # The caller wants the futures themselves.
              return pending

          driver_output = getattr(self.driver_worker, method)(*args, **kwargs)

          # Driver output first, then every worker result in worker order.
          return [driver_output, *[future.get() for future in pending]]

      def check_health(self) -> None:
          """Raise ``RuntimeError`` if the worker monitor is no longer alive.

          Nothing is checked when there is no worker monitor.
          """
          monitor = self.worker_monitor
          if monitor is None:
              return
          if monitor.is_alive():
              return
          raise RuntimeError("Worker processes are not running")

      def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
          """Block until every future in ``parallel_worker_tasks`` finishes.

          The futures are those returned from ``_run_workers()`` when it is
          called with ``async_run_tensor_parallel_workers_only``.
          """
          for task in parallel_worker_tasks:
              task.get()
  class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
                                        DistributedGPUExecutorAsync):
      """Async counterpart of ``MultiprocessingGPUExecutor``."""

      def __init__(self, *args, **kwargs):
          """Initialise the executor and wrap the driver's ``execute_model``.

          All arguments are forwarded unchanged to the parent initialiser.
          """
          super().__init__(*args, **kwargs)
          # Wrapped with make_async so the coroutine below can await it.
          driver_execute = self.driver_worker.execute_model
          self.driver_exec_model = make_async(driver_execute)

      async def _driver_execute_model_async(
          self,
          execute_model_req: Optional[ExecuteModelRequest] = None
      ) -> List[SamplerOutput]:
          """Await the wrapped driver ``execute_model`` and return its result."""
          sampler_outputs = await self.driver_exec_model(execute_model_req)
          return sampler_outputs

      async def _start_worker_execution_loop(self):
          """Run ``start_worker_execution_loop`` on every entry in
          ``self.workers`` and gather the results."""
          loop_coros = []
          for remote in self.workers:
              loop_coros.append(
                  remote.execute_method_async("start_worker_execution_loop"))
          return await asyncio.gather(*loop_coros)
  128. return await asyncio.gather(*coros)