# multiproc_gpu_executor.py

  1. import asyncio
  2. import os
  3. from functools import partial
  4. from typing import Any, Dict, Optional, Tuple
  5. from aphrodite.common.utils import (get_aphrodite_instance_id,
  6. get_distributed_init_method, get_ip,
  7. get_open_port, make_async)
  8. from aphrodite.executor.distributed_gpu_executor import ( # yapf: disable
  9. DistributedGPUExecutor, DistributedGPUExecutorAsync)
  10. from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
  11. ResultHandler,
  12. WorkerMonitor)
  13. class MultiprocessingGPUExecutor(DistributedGPUExecutor):
  14. """Python multiprocessing-based multi-GPU executor"""
  15. def _init_executor(self) -> None:
  16. # Create the parallel GPU workers.
  17. world_size = self.parallel_config.tensor_parallel_size
  18. # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
  19. if "CUDA_VISIBLE_DEVICES" not in os.environ:
  20. os.environ["CUDA_VISIBLE_DEVICES"] = (",".join(
  21. map(str, range(world_size))))
  22. # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
  23. os.environ["APHRODITE_INSTANCE_ID"] = get_aphrodite_instance_id()
  24. from torch.cuda import device_count
  25. assert world_size <= device_count(), (
  26. "please set tensor_parallel_size to less than max local gpu count")
  27. distributed_init_method = get_distributed_init_method(
  28. get_ip(), get_open_port())
  29. if world_size == 1:
  30. self.workers = []
  31. else:
  32. result_handler = ResultHandler()
  33. self.workers = [
  34. ProcessWorkerWrapper(
  35. result_handler,
  36. partial(
  37. self._create_worker,
  38. rank=rank,
  39. local_rank=rank,
  40. distributed_init_method=distributed_init_method,
  41. )) for rank in range(1, world_size)
  42. ]
  43. self.worker_monitor = WorkerMonitor(self.workers, result_handler)
  44. result_handler.start()
  45. self.worker_monitor.start()
  46. self.driver_worker = self._create_worker(
  47. distributed_init_method=distributed_init_method)
  48. self._run_workers("init_device")
  49. self._run_workers("load_model",
  50. max_concurrent_workers=self.parallel_config.
  51. max_parallel_loading_workers)
  52. def shutdown(self):
  53. if (worker_monitor := getattr(self, "worker_monitor",
  54. None)) is not None:
  55. worker_monitor.close()
  56. def _run_workers(
  57. self,
  58. method: str,
  59. *args,
  60. driver_args: Optional[Tuple[Any, ...]] = None,
  61. driver_kwargs: Optional[Dict[str, Any]] = None,
  62. max_concurrent_workers: Optional[int] = None,
  63. **kwargs,
  64. ) -> Any:
  65. """Runs the given method on all workers."""
  66. if max_concurrent_workers:
  67. raise NotImplementedError(
  68. "max_concurrent_workers is not supported yet.")
  69. # Start the workers first.
  70. worker_outputs = [
  71. worker.execute_method(method, *args, **kwargs)
  72. for worker in self.workers
  73. ]
  74. if driver_args is None:
  75. driver_args = args
  76. if driver_kwargs is None:
  77. driver_kwargs = kwargs
  78. # Start the driver worker after all the ray workers.
  79. driver_worker_method = getattr(self.driver_worker, method)
  80. driver_worker_output = driver_worker_method(*driver_args,
  81. **driver_kwargs)
  82. # Get the results of the workers.
  83. return [driver_worker_output
  84. ] + [output.get() for output in worker_outputs]
  85. def check_health(self) -> None:
  86. """Raises an error if engine is unhealthy."""
  87. if not self.worker_monitor.is_alive():
  88. raise RuntimeError("Worker processes are not running")
  89. class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
  90. DistributedGPUExecutorAsync):
  91. async def _run_workers_async(
  92. self,
  93. method: str,
  94. *args,
  95. driver_args: Optional[Tuple[Any, ...]] = None,
  96. driver_kwargs: Optional[Dict[str, Any]] = None,
  97. **kwargs,
  98. ) -> Any:
  99. """Runs the given method on all workers."""
  100. if driver_args is None:
  101. driver_args = args
  102. if driver_kwargs is None:
  103. driver_kwargs = kwargs
  104. driver_executor = make_async(getattr(self.driver_worker, method))
  105. # Run all the workers asynchronously.
  106. coros = [driver_executor(*driver_args, **driver_kwargs)] + [
  107. worker.execute_method_async(method, *args, **kwargs)
  108. for worker in self.workers
  109. ]
  110. return await asyncio.gather(*coros)