import asyncio
from abc import abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union

from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.gpu_executor import GPUExecutor
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.layers.sampler import SamplerOutput


class DistributedGPUExecutor(GPUExecutor):
    """Abstract superclass of multi-GPU executor implementations."""

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute-model loop is running in the
        # parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call.
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and
        # CPU by each worker.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)

        return num_gpu_blocks, num_cpu_blocks
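
    # Illustration of the reduction above, with hypothetical per-worker
    # results (the numbers are made up, not measured):
    #
    #     num_blocks = [(8192, 1024), (7936, 1024)]  # (gpu, cpu) per worker
    #     min(b[0] for b in num_blocks)  # -> 7936
    #     min(b[1] for b in num_blocks)  # -> 1024
    #
    # The shared scheduler therefore never issues a cache operation that the
    # most memory-constrained worker cannot satisfy.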

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache in all workers."""

        # NOTE: We log here to avoid multiple logs when the number of workers
        # is greater than one. We could log in the engine, but not all
        # executors have GPUs.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        logger.info(
            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
        )

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)
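
    # Worked example for the concurrency figure (hypothetical values): with
    # num_gpu_blocks = 8000, block_size = 16, and max_model_len = 4096, the
    # cache holds 8000 * 16 = 128,000 tokens, which fits
    # 128000 / 4096 = 31.25x max-length sequences at once.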

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_tensor_parallel_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        driver_outputs = self._driver_execute_model(execute_model_req)
        assert driver_outputs is not None
        return driver_outputs
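
    # Call-flow sketch (schematic, not code in this module): the first call
    # lazily starts the busy loop on the remote TP workers, and every call
    # then executes one step on the driver, which feeds the looping workers:
    #
    #     executor.execute_model(req1)  # starts worker loops + runs step 1
    #     executor.execute_model(req2)  # runs step 2 (loops already running)
    #     executor.stop_remote_worker_execution_loop()  # workers exit loop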

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        # Passing None signals the remote workers to exit their busy loop.
        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "add_lora",
            lora_request=lora_request,
        )

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "remove_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        )
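
    # Usage sketch. The LoRARequest constructor arguments shown here are
    # assumptions for illustration, not taken from this module:
    #
    #     request = LoRARequest("my-adapter", 1, "/path/to/adapter")
    #     executor.add_lora(request)       # loads the adapter on every worker
    #     assert 1 in executor.list_loras()
    #     executor.remove_lora(1)          # unloads it everywhere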

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self._run_workers("save_sharded_state",
                          path=path,
                          pattern=pattern,
                          max_size=max_size)

    @abstractmethod
    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError

    @abstractmethod
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of
                futures rather than blocking on the results.
        """
        raise NotImplementedError
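
    # A minimal override sketch. `self.workers`, `self.driver_worker`, and
    # the future-returning `execute_method` handle are assumptions about the
    # subclass, not part of this base class:
    #
    #     def _run_workers(self, method, *args,
    #                      async_run_tensor_parallel_workers_only=False,
    #                      max_concurrent_workers=None, **kwargs):
    #         if async_run_tensor_parallel_workers_only:
    #             # Non-blocking: hand back futures for the remote TP workers.
    #             return [w.execute_method(method, *args, **kwargs)
    #                     for w in self.workers]
    #         driver_out = getattr(self.driver_worker, method)(*args, **kwargs)
    #         worker_outs = [w.execute_method(method, *args, **kwargs).result()
    #                        for w in self.workers]
    #         return [driver_out, *worker_outs]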

    @abstractmethod
    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only to complete."""
        raise NotImplementedError


class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start the model execution loop running in the parallel workers.
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        # Calling with no request signals the remote workers to exit the loop.
        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise).
        await parallel_worker_tasks

    @abstractmethod
    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Execute the model asynchronously in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        raise NotImplementedError

    @abstractmethod
    async def _start_worker_execution_loop(self):
        """Run the execution loop on all workers.

        Guarantees that either all workers run the loop or none of them do.
        The loop can be stopped by `stop_remote_worker_execution_loop`.
        The API is idempotent: at most one loop runs at any moment.
        """
        raise NotImplementedError
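
    # A minimal override sketch. The awaitable `execute_method` handle on
    # `self.workers` is an assumption, not part of this base class:
    #
    #     async def _start_worker_execution_loop(self):
    #         coros = [w.execute_method("start_worker_execution_loop")
    #                  for w in self.workers]
    #         # gather() starts all loops together and surfaces any failure,
    #         # giving the all-or-none guarantee described above.
    #         return await asyncio.gather(*coros)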