distributed_gpu_executor.py

import asyncio
from abc import abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.gpu_executor import GPUExecutor
from aphrodite.lora.request import LoRARequest


class DistributedGPUExecutor(GPUExecutor):
    """Abstract superclass of multi-GPU executor implementations."""

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the
        # AsyncLLMEngine case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None

        # Updated by implementations that require additional args to be
        # passed to the _run_workers execute_model call.
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and
        takes the min of the results, guaranteeing that the selected cache
        sizes are compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU
        # and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks
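
    # Example with hypothetical numbers: if two workers report
    # (8000, 1024) and (7900, 1024) blocks respectively, the min-reduction
    # above yields (7900, 1024), so the chosen cache size fits even on the
    # worker with the least free memory.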
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers."""
        # NOTE: We log here to avoid multiple logs when the number of
        # workers is greater than one. We could log in the engine, but
        # not all executors have GPUs.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        logger.info(
            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
        )
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)
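
    # Worked example of the concurrency figure logged above, with
    # hypothetical numbers: 7000 GPU blocks * block_size 16 tokens /
    # max_model_len 4096 tokens ≈ 27.34x, i.e. about 27 maximum-length
    # sequences can have their KV cache resident at once.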
    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        self._wait_for_tasks_completion(parallel_worker_tasks)
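
    # Stop handshake, informally: calling _driver_execute_model(None) is
    # assumed to make the driver broadcast an "empty" request, on which the
    # remote workers' start_worker_execution_loop returns; roughly (worker
    # side, hypothetical -- the worker class is not defined in this file):
    #
    #     def start_worker_execution_loop(self):
    #         while self.execute_model(execute_model_req=None) is not None:
    #             pass
    #
    # _wait_for_tasks_completion then surfaces any worker-side exception.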
    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self._run_workers("save_sharded_state",
                          path=path,
                          pattern=pattern,
                          max_size=max_size)

    @abstractmethod
    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_remote_workers_only: If True, the method will be run
                only in the remote workers, not the driver worker. It will
                also be run asynchronously and return a list of futures
                rather than blocking on the results.
        """
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        raise NotImplementedError
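
# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the executor API: one way a concrete
# subclass *might* fill in the abstract hooks above, using a thread pool.
# `self.driver_worker` and `self.workers` are hypothetical attributes that a
# real implementation would create during worker initialization (the actual
# Ray/multiprocessing executors are substantially more involved); they are
# assumed to expose the methods named in _run_workers calls, e.g.
# "determine_num_available_blocks" and "initialize_cache".
# ---------------------------------------------------------------------------
from concurrent.futures import Future, ThreadPoolExecutor


class _ThreadedGPUExecutorSketch(DistributedGPUExecutor):
    """Example-only subclass; see the note above."""

    # Shared pool for dispatching worker calls off the driver thread.
    _pool = ThreadPoolExecutor(max_workers=8)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        # Passing None through to the driver worker is what, per the
        # contract documented above, makes the remote loops exit.
        return self.driver_worker.execute_model(execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        # Dispatch to the remote workers first so their work overlaps with
        # the driver's. (max_concurrent_workers is ignored in this sketch.)
        futures: List[Future] = [
            self._pool.submit(getattr(worker, method), *args, **kwargs)
            for worker in self.workers
        ]
        if async_run_remote_workers_only:
            # Return the futures without blocking; the caller hands them
            # back to _wait_for_tasks_completion later.
            return futures
        driver_result = getattr(self.driver_worker, method)(*args, **kwargs)
        return [driver_result] + [f.result() for f in futures]

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        # Future.result() re-raises any exception captured in the worker,
        # satisfying the "this will raise otherwise" note in the caller.
        for future in parallel_worker_tasks:
            future.result()
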
class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start the model execution loop running in the parallel
            # workers.
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run the execution loop on all workers. It guarantees that either
        all workers run the loop or none of them do. The loop can be
        stopped by `stop_remote_worker_execution_loop`.

        The API is idempotent (it guarantees that only one loop is running
        at any moment)."""
        raise NotImplementedError
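
# ---------------------------------------------------------------------------
# Illustrative only: a self-contained toy of the loop lifecycle that
# DistributedGPUExecutorAsync manages. The toy loop stands in for
# _start_worker_execution_loop; setting the stop event plays the role of the
# driver's empty broadcast, and awaiting the task mirrors the exception
# propagation in stop_remote_worker_execution_loop_async.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    async def _demo() -> None:
        stop = asyncio.Event()

        async def _toy_worker_loop() -> None:
            # Stand-in for the remote workers' execution loop.
            while not stop.is_set():
                await asyncio.sleep(0.01)  # "execute one model step"

        # execute_model_async path: lazily start the loop exactly once.
        loop_task = asyncio.create_task(_toy_worker_loop())
        await asyncio.sleep(0.05)  # serve some "requests"
        # stop_remote_worker_execution_loop_async path: signal, then await
        # the task so any exception raised inside the loop propagates here.
        stop.set()
        await loop_task

    asyncio.run(_demo())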