# distributed_gpu_executor.py

import asyncio
from abc import abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.gpu_executor import GPUExecutor
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput


class DistributedGPUExecutor(GPUExecutor):
    """Abstract superclass of multi-GPU executor implementations."""

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None

        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call.
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and
        # CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks
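
    # Worked illustration (hypothetical numbers): if two workers report
    # (7964, 2048) and (7920, 2048) available (GPU, CPU) blocks, the
    # element-wise min above selects (7920, 2048), a cache size that both
    # workers can accommodate.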

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers."""
        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        logger.info(
            "Minimum concurrency: "
            f"{num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
        )
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)
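
    # Worked illustration (hypothetical numbers) for the concurrency figure
    # logged above: with num_gpu_blocks=7920, block_size=16, and
    # max_model_len=4096, minimum concurrency = 7920 * 16 / 4096 ≈ 30.94x,
    # i.e. roughly 31 maximum-length sequences fit in the KV cache at once.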

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        driver_outputs = self._driver_execute_model(execute_model_req)
        assert driver_outputs is not None
        return driver_outputs
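
    # Note: the first execute_model() call is what starts the busy loop on
    # the non-driver tensor-parallel workers; subsequent calls only execute
    # inline on the driver, which (in typical implementations) broadcasts
    # the inputs to the looping workers via collective operations.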

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        self._wait_for_tasks_completion(parallel_worker_tasks)
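
    # Note: _driver_execute_model(None) above signals the remote workers to
    # exit their execution loop (see the _driver_execute_model docstring
    # below), and _wait_for_tasks_completion then surfaces any exception a
    # worker raised while exiting.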

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        )

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self._run_workers("save_sharded_state",
                          path=path,
                          pattern=pattern,
                          max_size=max_size)
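
    # The exact semantics of `pattern` and `max_size` are defined by the
    # workers' save_sharded_state implementation; typically (an assumption
    # here) they control the shard filename template and the maximum size
    # in bytes of each saved shard.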

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    @abstractmethod
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True, the method will
                be run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of
                futures rather than blocking on the results.
        """
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        raise NotImplementedError


class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start the model execution loop running in the parallel workers.
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)
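
    # Note: an exception raised inside the task created above does not
    # propagate here; it surfaces when the task is awaited in
    # stop_remote_worker_execution_loop_async().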

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run the execution loop on all workers.

        Guarantees that either all workers run the loop or none of them do.
        The loop can be stopped by `stop_remote_worker_execution_loop`. The
        API is idempotent: at most one loop runs at any moment.
        """
        raise NotImplementedError
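

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal concrete
# subclass showing how the abstract hooks above compose. The `driver_worker`
# and `workers` attributes and the workers' `execute_method_async` helper are
# assumptions made for this example; real executors (e.g. multiprocessing- or
# Ray-based ones) create and address their workers differently.
# ---------------------------------------------------------------------------
class _SketchDistributedGPUExecutor(DistributedGPUExecutor):

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        # Run inline on the (assumed) driver worker; passing None through
        # tells the remote workers' execution loop to exit.
        return self.driver_worker.execute_model(execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        if async_run_tensor_parallel_workers_only:
            # Dispatch to the remote TP workers only and return futures
            # instead of blocking on results (see the base-class docstring).
            return [
                worker.execute_method_async(method, *args, **kwargs)
                for worker in self.workers
            ]
        # Synchronous path: run on the driver first, then every remote worker.
        return [
            getattr(worker, method)(*args, **kwargs)
            for worker in [self.driver_worker, *self.workers]
        ]

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        # Block until every future resolves; .result() re-raises any
        # exception a worker hit while leaving its execution loop.
        for task in parallel_worker_tasks:
            task.result()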