import asyncio
import os
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
                    Set, Tuple, Type, Union)

from loguru import logger
from transformers import PreTrainedTokenizer

from aphrodite.common.config import (DecodingConfig, EngineConfig, LoRAConfig,
                                     ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_timeout import asyncio_timeout
from aphrodite.engine.metrics import StatLoggerBase
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.ray_utils import initialize_ray_cluster, ray
from aphrodite.inputs import LLMInputs, PromptInputs
from aphrodite.lora.request import LoRARequest
from aphrodite.processing.scheduler import SchedulerOutputs
from aphrodite.prompt_adapter.request import PromptAdapterRequest

# Upper bound (seconds) on a single iteration of the engine loop. If it is
# exceeded, `run_engine_loop` marks the engine as errored and re-raises.
# Override with the APHRODITE_ENGINE_ITERATION_TIMEOUT_S environment variable.
ENGINE_ITERATION_TIMEOUT_S = int(
    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", "60"))


class AsyncEngineDeadError(RuntimeError):
    """Raised when the background engine loop has failed or is not running."""


def _log_task_completion(task: asyncio.Task,
                         error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the `engine.run_engine_loop()` task.

    That task runs a `while True` loop that can only exit through an
    exception (or cancellation), so any completion is inspected here.

    Args:
        task: The finished background-loop task.
        error_callback: Called with the exception when the task failed
            (including the synthetic AssertionError for a normal return).

    Raises:
        AsyncEngineDeadError: If the task ended for any reason other than
            cancellation. The original exception is chained as the cause.
    """
    try:
        return_value = task.result()
        # Reaching this line means the loop returned normally, which is a
        # bug; raise so that the handler below reports it as a failure.
        raise AssertionError(
            f"The engine background task should never finish without an "
            f"exception. {return_value}")
    except asyncio.exceptions.CancelledError:
        # We assume that if the task is cancelled, we are gracefully shutting
        # down. This should only happen on program exit.
        logger.info("Engine is gracefully shutting down.")
    except Exception as e:
        # loguru ignores the stdlib-style `exc_info` keyword; `opt(exception=)`
        # is what makes it actually print the traceback referenced below.
        logger.opt(exception=e).error("Engine background task failed")
        error_callback(e)
        raise AsyncEngineDeadError(
            "Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github. See stack trace above for the "
            "actual cause.") from e


class AsyncStream:
    """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
    that can be iterated over asynchronously."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                              Exception]) -> None:
        """Enqueue an output or an exception; ignored once finished."""
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        """Terminate iteration by enqueueing a StopAsyncIteration sentinel."""
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
        result = await self._queue.get()
        # Exceptions (including the StopAsyncIteration sentinel placed by
        # `finish`) are raised to the consumer instead of being yielded.
        if isinstance(result, Exception):
            raise result
        return result


class RequestTracker:
    """Synchronous abstraction for tracking requests."""

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def __contains__(self, item):
        return item in self._request_streams

    def __len__(self) -> int:
        return len(self._request_streams)

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self._request_streams[request_id].put(exc)
            self.abort_request(request_id)
        else:
            # `abort_request` does not mutate `_request_streams`, so iterating
            # the live dict here is safe.
            for rid, stream in self._request_streams.items():
                stream.put(exc)
                self.abort_request(rid)

    def process_request_output(self,
                               request_output: Union[RequestOutput,
                                                     EmbeddingRequestOutput],
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine."""
        request_id = request_output.request_id

        # Guard against a KeyError which can occur if the request was aborted
        # while the output was generated
        if (stream := self._request_streams.get(request_id)) is not None:
            stream.put(request_output)
        if request_output.finished:
            if verbose:
                logger.info(f"Finished request {request_id}.")
            self.abort_request(request_id)

    def process_exception(self,
                          request_id: str,
                          exception: Exception,
                          *,
                          verbose: bool = False) -> None:
        """Propagate an exception from the engine."""
        self._request_streams[request_id].put(exception)
        if verbose:
            logger.info(f"Finished request {request_id}.")
        self.abort_request(request_id)

    def add_request(self,
                    request_id: str,
                    *,
                    verbose: bool = False,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Raises:
            KeyError: If a stream with this request_id is already tracked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        # Wake up the engine loop if it is waiting for work.
        self.new_requests_event.set()

        if verbose:
            logger.info(f"Added request {request_id}.")

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration."""
        if verbose:
            logger.info(f"Aborted request {request_id}.")

        # Always queue the id so the engine side is aborted even if the
        # stream has not been registered yet (see
        # `get_new_and_finished_requests`).
        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be sent to the
        engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until at least one new request has been queued."""
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()

    def has_new_requests(self):
        return not self._new_requests.empty()


class _AsyncAphrodite(AphroditeEngine):
    """Extension of AphroditeEngine to add async methods."""

    async def step_async(
        self, virtual_engine: int
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are ran asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.

        Args:
            virtual_engine: Index of the scheduler (pipeline stage) to step.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler[
            virtual_engine].schedule()

        if not scheduler_outputs.is_empty():
            # Execute the model.
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                virtual_engine=virtual_engine,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
            )
            output = await self.model_executor.execute_model_async(
                execute_model_req)
        else:
            output = []

        request_outputs = self._process_model_outputs(
            output, scheduler_outputs.scheduled_seq_groups,
            scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)

        # Log stats.
        self.do_log_stats(scheduler_outputs, output)

        return request_outputs

    async def stop_remote_worker_execution_loop_async(self) -> None:
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def process_model_inputs_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Normalize `inputs` into LLMInputs, tokenizing the prompt if needed.

        `inputs` may be a plain prompt string, or a mapping that provides
        "prompt" and/or "prompt_token_ids" (and optionally
        "multi_modal_data"). When a prompt adapter is given, that many
        placeholder token ids (0) are prepended to the prompt tokens.
        """
        if isinstance(inputs, str):
            inputs = {"prompt": inputs}

        if "prompt_token_ids" not in inputs:
            tokenizer = self.get_tokenizer_group("prompts must be None if "
                                                 "skip_tokenizer_init is True")

            prompt_token_ids = await tokenizer.encode_async(
                request_id=request_id,
                prompt=inputs["prompt"],
                lora_request=lora_request)
        else:
            prompt_token_ids = inputs["prompt_token_ids"]

        if prompt_adapter_request:
            prompt_token_ids = [
                0
            ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
                prompt_token_ids

        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
                               prompt=inputs.get("prompt"),
                               multi_modal_data=inputs.get("multi_modal_data"))

        return self.input_processor(llm_inputs)

    async def add_request_async(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Async counterpart of adding a request to the engine.

        Raises:
            ValueError: If a LoRA request is given but LoRA is not enabled.
        """
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = await self.process_model_inputs_async(
            request_id=request_id,
            inputs=inputs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        self._add_processed_request(
            request_id=request_id,
            processed_inputs=processed_inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def check_health_async(self) -> None:
        """Check the tokenizer (if any) and the model executor."""
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()


class AsyncAphrodite:
    """An asynchronous wrapper for AphroditeEngine.

    This class is used to wrap the AphroditeEngine class to make it
    asynchronous. It uses asyncio to create a background loop that keeps
    processing incoming requests. The AphroditeEngine is kicked by the
    generate method when there are requests in the waiting queue.
    The generate method yields the outputs from the AphroditeEngine
    to the caller.

    NOTE: For the comprehensive list of arguments, see `AphroditeEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make AphroditeEngine a Ray actor. If so,
            the async frontend will be executed in a separate process as the
            model workers.
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for AphroditeEngine.
        **kwargs: Arguments for AphroditeEngine.
    """

    _engine_class: Type[_AsyncAphrodite] = _AsyncAphrodite

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop: Optional[asyncio.Future] = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded: Optional[asyncio.Task] = None
        self.start_engine_loop = start_engine_loop
        self._errored_with: Optional[BaseException] = None

        # Lazy initialized fields
        self._request_tracker: RequestTracker

    @classmethod
    def _get_executor_cls(
            cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
        """Pick the async executor class for the configured device/backend.

        Also initializes the Ray cluster when the chosen executor uses Ray.

        Raises:
            TypeError: If a custom backend class is not an ExecutorAsyncBase.
            RuntimeError: If XPU is paired with an unsupported backend.
        """
        distributed_executor_backend = (
            engine_config.parallel_config.distributed_executor_backend)
        if isinstance(distributed_executor_backend, type):
            if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
                raise TypeError(
                    "distributed_executor_backend must be a subclass of "
                    f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
            if distributed_executor_backend.uses_ray:  # type: ignore
                initialize_ray_cluster(engine_config.parallel_config)
            executor_class = distributed_executor_backend
        elif engine_config.device_config.device_type == "neuron":
            from aphrodite.executor.neuron_executor import NeuronExecutorAsync
            executor_class = NeuronExecutorAsync
        elif engine_config.device_config.device_type == "tpu":
            if distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from aphrodite.executor.ray_tpu_executor import (
                    RayTPUExecutorAsync)
                executor_class = RayTPUExecutorAsync
            else:
                assert distributed_executor_backend is None
                from aphrodite.executor.tpu_executor import TPUExecutorAsync
                executor_class = TPUExecutorAsync
        elif engine_config.device_config.device_type == "cpu":
            from aphrodite.executor.cpu_executor import CPUExecutorAsync
            executor_class = CPUExecutorAsync
        elif engine_config.device_config.device_type == "openvino":
            assert distributed_executor_backend is None, (
                "Distributed execution is not supported with the OpenVINO "
                "backend.")
            from aphrodite.executor.openvino_executor import (
                OpenVINOExecutorAsync)
            executor_class = OpenVINOExecutorAsync
        elif engine_config.device_config.device_type == "xpu":
            if distributed_executor_backend is None:
                from aphrodite.executor.xpu_executor import XPUExecutorAsync
                executor_class = XPUExecutorAsync
            elif distributed_executor_backend == "ray":
                initialize_ray_cluster(engine_config.parallel_config)
                from aphrodite.executor.ray_xpu_executor import (
                    RayXPUExecutorAsync)
                executor_class = RayXPUExecutorAsync
            else:
                raise RuntimeError(
                    "Unsupported distributed executor backend for XPU.")
        elif distributed_executor_backend == "ray":
            initialize_ray_cluster(engine_config.parallel_config)
            from aphrodite.executor.ray_gpu_executor import RayGPUExecutorAsync
            executor_class = RayGPUExecutorAsync
        elif distributed_executor_backend == "mp":
            from aphrodite.executor.multiproc_gpu_executor import (
                MultiprocessingGPUExecutorAsync)
            executor_class = MultiprocessingGPUExecutorAsync
        else:
            from aphrodite.executor.gpu_executor import GPUExecutorAsync
            executor_class = GPUExecutorAsync
        return executor_class

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        start_engine_loop: bool = True,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncAphrodite":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        if engine_args.engine_use_ray:
            from aphrodite.executor import ray_utils
            ray_utils.assert_ray_available()

        executor_class = cls._get_executor_cls(engine_config)

        # Create the async LLM engine.
        engine = cls(
            executor_class.uses_ray,
            engine_args.engine_use_ray,
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            stat_loggers=stat_loggers,
        )
        return engine

    @property
    def is_running(self) -> bool:
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())

    @property
    def is_stopped(self) -> bool:
        return self.errored or (self.background_loop is not None and
                                self._background_loop_unshielded is not None
                                and self._background_loop_unshielded.done())

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc

    def _error_callback(self, exc: Exception) -> None:
        """Mark the engine as errored and fail every tracked request."""
        self.set_errored(exc)
        self._request_tracker.propagate_exception(exc)

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer (LoRA-specific one if a request is given)."""
        if self.engine_use_ray:
            return await self.engine.get_tokenizer.remote(  # type: ignore
                lora_request)

        return await (self.engine.get_tokenizer_group().
                      get_lora_tokenizer_async(lora_request))

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            AsyncEngineDeadError: If the loop has already errored.
            RuntimeError: If the loop is already running.
        """
        if self.errored:
            raise AsyncEngineDeadError(
                "Background loop has errored already.") from self._errored_with
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        # Initialize the RequestTracker here so it uses the right event loop.
        self._request_tracker = RequestTracker()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop())
        self._background_loop_unshielded.add_done_callback(
            partial(_log_task_completion, error_callback=self._error_callback))
        # Shield so that cancelling an awaiter of `background_loop` does not
        # cancel the engine loop itself.
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncAphrodite, "ray.ObjectRef"]:
        """Construct the engine locally, or as a Ray actor if requested."""
        if not self.engine_use_ray:
            engine_class = self._engine_class
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
        else:
            # FIXME: This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = kwargs["cache_config"]
            parallel_config = kwargs["parallel_config"]
            if (parallel_config.tensor_parallel_size == 1
                    and parallel_config.pipeline_parallel_size == 1):
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            engine_class = ray.remote(num_gpus=num_gpus)(
                self._engine_class).remote
        return engine_class(*args, **kwargs)

    async def engine_step(self, virtual_engine: int) -> bool:
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""

        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for new_request in new_requests:
            # Add the request into the Aphrodite engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            try:
                if self.engine_use_ray:
                    await self.engine.add_request.remote(  # type: ignore
                        **new_request)
                else:
                    await self.engine.add_request_async(**new_request)
            except ValueError as e:
                # TODO: use an Aphrodite specific error for failed validation
                self._request_tracker.process_exception(
                    new_request["request_id"],
                    e,
                    verbose=self.log_requests,
                )

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()  # type: ignore
        else:
            request_outputs = await self.engine.step_async(virtual_engine)

        # Put the outputs into the corresponding streams.
        finished = True
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)
            finished = finished and request_output.finished

        return not finished

    async def _engine_abort(self, request_ids: Iterable[str]):
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)  # type: ignore
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        """Drive `engine_step` for every virtual engine until failure.

        Only exits by raising (e.g. on an iteration timeout) or by being
        cancelled.
        """
        if self.engine_use_ray:
            pipeline_parallel_size = 1  # type: ignore
        else:
            pipeline_parallel_size = \
                self.engine.parallel_config.pipeline_parallel_size
        has_requests_in_progress = [False] * pipeline_parallel_size
        while True:
            if not any(has_requests_in_progress):
                logger.debug("Waiting for new requests...")
                # Stop the execute model loop in parallel workers until there
                # are more requests to process. This avoids waiting
                # indefinitely in torch.distributed ops which may otherwise
                # timeout, and unblocks the RPC thread in the workers so that
                # they can process any other queued control plane messages,
                # such as add/remove lora adapters.
                if self.engine_use_ray:
                    await (self.engine.stop_remote_worker_execution_loop.
                           remote()  # type: ignore
                           )
                else:
                    await self.engine.stop_remote_worker_execution_loop_async()
                await self._request_tracker.wait_for_new_requests()
                logger.debug("Got new requests!")
                requests_in_progress = [
                    asyncio.create_task(self.engine_step(ve))
                    for ve in range(pipeline_parallel_size)
                ]
                has_requests_in_progress = [True] * pipeline_parallel_size

            # Abort if iteration takes too long due to unrecoverable errors
            # (eg. NCCL timeouts).
            try:
                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                for task in done:
                    result = task.result()
                    virtual_engine = requests_in_progress.index(task)
                    if self.engine_use_ray:
                        has_unfinished_requests = (
                            await (self.engine.
                                   has_unfinished_requests_for_virtual_engine.
                                   remote(  # type: ignore
                                       virtual_engine)))
                    else:
                        has_unfinished_requests = (
                            self.engine.
                            has_unfinished_requests_for_virtual_engine(
                                virtual_engine))
                    # Reschedule this virtual engine while it still has work.
                    if result or has_unfinished_requests:
                        requests_in_progress[virtual_engine] = (
                            asyncio.create_task(
                                self.engine_step(virtual_engine)))
                        has_requests_in_progress[virtual_engine] = True
                    else:
                        has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
                self.set_errored(exc)
                raise
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> AsyncStream:
        """Register a request with the tracker and return its output stream.

        Starts the background loop first if it is not running and
        `start_engine_loop` is set.

        Raises:
            AsyncEngineDeadError: If the loop is not running and may not be
                started automatically.
        """
        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        if arrival_time is None:
            arrival_time = time.time()

        stream = self._request_tracker.add_request(
            request_id,
            verbose=self.log_requests,
            inputs=inputs,
            params=params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        return stream

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the AphroditeEngine and streams the
        outputs from the AphroditeEngine to the caller.

        Args:
            inputs: The inputs to the LLM: either a prompt string, or a
                mapping providing "prompt" and/or "prompt_token_ids" (and
                optionally "multi_modal_data"). If no "prompt_token_ids" are
                given, the prompt is tokenized.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use
                for generation, if any.

        Yields:
            The output `RequestOutput` objects from the AphroditeEngine
            for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              # pylint: disable=line-too-long
              :meth:`~aphrodite.engine.async_aphrodite.AsyncAphrodite.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncAphrodite.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in self._process_request(
                request_id,
                inputs,
                sampling_params,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
        ):
            yield AphroditeEngine.validate_output(output, RequestOutput)

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Generate outputs for a request from an embedding model.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the AphroditeEngine and streams the
        outputs from the AphroditeEngine to the caller.

        Args:
            inputs: The inputs to the LLM: either a prompt string, or a
                mapping providing "prompt" and/or "prompt_token_ids" (and
                optionally "multi_modal_data"). If no "prompt_token_ids" are
                given, the prompt is tokenized.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.

        Yields:
            The output `EmbeddingRequestOutput` objects from the
            AphroditeEngine for the request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~aphrodite.engine.async_aphrodite.AsyncAphrodite.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.

        Example:
            >>> # initialize the engine and the example input
            >>> engine = AsyncAphrodite.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "input": "What is LLM?",
            >>>     "request_id": 0,
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.encode(
            >>>    example_input["input"],
            >>>    PoolingParams(),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(request_id)
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        async for output in self._process_request(
                request_id,
                inputs,
                pooling_params,
                lora_request=lora_request,
        ):
            yield AphroditeEngine.validate_output(output,
                                                  EmbeddingRequestOutput)

    async def _process_request(
        self,
        request_id: str,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Common logic to process requests with SamplingParams or
        PoolingParams."""
        arrival_time = time.time()

        stream = await self.add_request(
            request_id,
            inputs,
            params,
            arrival_time=arrival_time,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError):
            # Abort on any failure or cancellation (e.g. client disconnect)
            # so the engine stops working on this request, then re-raise
            # with the original traceback intact.
            self._abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the Aphrodite engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()  # type: ignore
        else:
            return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the Aphrodite engine."""
        if self.engine_use_ray:
            return await self.engine.get_parallel_config.remote(  # type: ignore
            )
        else:
            return self.engine.get_parallel_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the Aphrodite engine."""
        if self.engine_use_ray:
            return await self.engine.get_decoding_config.remote(  # type: ignore
            )
        else:
            return self.engine.get_decoding_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the Aphrodite engine."""
        if self.engine_use_ray:
            return await self.engine.get_scheduler_config.remote(  # type: ignore
            )
        else:
            return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the Aphrodite engine."""
        if self.engine_use_ray:
            return await self.engine.get_lora_config.remote(  # type: ignore
            )
        else:
            return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
            model_output: Optional[List[SamplerOutput]] = None) -> None:
        """Ask the engine to log its stats for the given step outputs."""
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote(  # type: ignore
                scheduler_outputs, model_output)
        else:
            # Forward the arguments like the Ray path does; previously they
            # were silently dropped for the in-process engine.
            self.engine.do_log_stats(scheduler_outputs, model_output)

    async def check_health(self) -> None:
        """Raises an error if engine is unhealthy."""
        t = time.perf_counter()
        logger.debug("Starting health check...")
        if self.is_stopped:
            raise AsyncEngineDeadError("Background loop is stopped.")

        if self.engine_use_ray:
            try:
                await self.engine.check_health.remote()  # type: ignore
            except ray.exceptions.RayActorError as e:
                raise RuntimeError("Engine is dead.") from e
        else:
            await self.engine.check_health_async()
        logger.debug(f"Health check took {time.perf_counter()-t}s")