import asyncio
from abc import abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.gpu_executor import GPUExecutor
from aphrodite.lora.request import LoRARequest


class DistributedGPUExecutor(GPUExecutor):
    """Abstract superclass of multi-GPU executor implementations.

    Subclasses supply `_driver_execute_model`, `_run_workers` and
    `_wait_for_tasks_completion`; this class builds the public executor
    operations on top of those three primitives.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: these attributes are assigned *before* delegating to the
        # parent constructor, so they already exist while it runs.

        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks", )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers.

        Records the block counts on `self.cache_config` and then forwards
        them to every worker's `initialize_cache`.

        Args:
            num_gpu_blocks: Number of GPU blocks each worker should allocate.
            num_cpu_blocks: Number of CPU blocks each worker should allocate.
        """
        # NOTE: We log here to avoid multiple logs when number of workers is
        # greater than one. We could log in the engine, but not all executors
        # have GPUs.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        # Total token slots in the GPU cache divided by the longest sequence
        # the scheduler accepts.
        concurrency = (num_gpu_blocks * self.cache_config.block_size /
                       self.scheduler_config.max_model_len)
        logger.info(f"Minimum concurrency: {concurrency:.2f}x")

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one model step and return the driver worker's output.

        On the first call (or the first call after the loop was stopped) the
        remote workers' execution loop is started via `_run_workers`; later
        calls reuse the already-running loop.

        Args:
            execute_model_req: The request to execute on the driver worker.

        Returns:
            Whatever `_driver_execute_model` returns for this request.
        """
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        """Stop the remote workers' execution loop, if one is running.

        Does nothing when no loop has been started.
        """
        if self.parallel_worker_tasks is None:
            return

        # Calling the driver without a request is the stop signal for the
        # remote workers' loop (see `_driver_execute_model`).
        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter on every worker.

        Args:
            lora_request: The adapter to add; its `lora_int_id` must be > 0.

        Returns:
            True only if every worker reported a truthy result.
        """
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        # `_run_workers` yields one result per worker. Returning that
        # collection directly would be truthy even when every worker failed,
        # so reduce it to the single bool the signature promises.
        return all(self._run_workers(
            "add_lora",
            lora_request=lora_request,
        ))

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter from every worker.

        Args:
            lora_id: Identifier of the adapter to remove; must be > 0.

        Returns:
            True only if every worker reported a truthy result.
        """
        assert lora_id > 0, "lora_id must be greater than 0."
        # Same per-worker reduction as in `add_lora`.
        return all(self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        ))

    def list_loras(self) -> Set[int]:
        """Return the LoRA adapter IDs that are present on every worker.

        Returns:
            The intersection of the per-worker results, or an empty set if
            `_run_workers` produced no results.
        """
        per_worker = [set(ids) for ids in self._run_workers("list_loras")]
        if not per_worker:
            return set()
        # NOTE(review): workers are presumably expected to hold identical
        # adapter sets, in which case this equals any single worker's answer.
        # The intersection is the conservative choice if they ever diverge.
        return set.intersection(*per_worker)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Invoke `save_sharded_state` on every worker.

        All three arguments are forwarded unchanged as keyword arguments.
        """
        self._run_workers("save_sharded_state",
                          path=path,
                          pattern=pattern,
                          max_size=max_size)

    @abstractmethod
    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_remote_workers_only: If True the method will be run only
                in the remote workers, not the driver worker. It will also be
                run asynchronously and return a list of futures rather than
                blocking on the results.
        """
        raise NotImplementedError

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        raise NotImplementedError


class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
    """Asyncio counterpart of `DistributedGPUExecutor`.

    The remote workers' execution loop is held as an `asyncio` task in
    `self.parallel_worker_tasks`.
    """

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one model step asynchronously and return the driver's output.

        Starts the remote workers' execution loop as a background task if it
        is not already running.

        Args:
            execute_model_req: The request to execute on the driver worker.
        """
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote workers' execution loop, if one is running.

        Does nothing when no loop has been started.
        """
        if self.parallel_worker_tasks is None:
            return

        # Calling the driver without a request is the stop signal for the
        # remote workers' loop (see `_driver_execute_model_async`).
        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run execution loop on all workers. It guarantees all workers run
        the loop or None of them is running the loop. Loop can be stopped by
        `stop_remote_worker_execution_loop`.
        The API is idempotent (guarantee only 1 loop run at any moment)."""
        raise NotImplementedError