from contextlib import contextmanager
from typing import Any, AsyncGenerator, Optional

import cloudpickle
import zmq
import zmq.asyncio

from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
                                            APHRODITE_RPC_SUCCESS_STR,
                                            RPC_REQUEST_TYPE, RPCAbortRequest,
                                            RPCGenerateRequest,
                                            RPCUtilityRequest)
from aphrodite.inputs import PromptInputs
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.transformers_utils.tokenizer_group import (
    init_tokenizer_from_configs)

# Time to wait before checking if the server process is alive
SERVER_START_TIMEOUT_MS = 1000


class AsyncEngineRPCClient:
    """Client side of the multiprocessing RPC backend.

    Talks to an ``AsyncEngineRPCServer`` over ZeroMQ: one-shot
    request/reply exchanges for configuration and control messages, and a
    DEALER socket per ``generate()`` call so the server can stream back
    multiple ``RequestOutput`` frames for a single request. All payloads
    are (de)serialized with cloudpickle.
    """

    def __init__(self, rpc_path: str):
        self.context = zmq.asyncio.Context()
        self.rpc_path = rpc_path
        # Initialize here rather than in setup() so that the `errored` /
        # `is_running` / `is_stopped` properties never raise
        # AttributeError when consulted before (or during a failed)
        # setup() call.
        self._errored = False

    async def setup(self):
        """Setup the client before it starts sending server requests."""

        # Wait until server is ready.
        await self.wait_for_server()
        # Reset the error flag in case the client is being re-setup.
        self._errored = False

        # Get the configs.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Destroy the ZeroMQ Context."""
        self.context.destroy()

    @contextmanager
    def socket(self):
        # Ensure client sockets are always closed after use

        # Connect to RPC socket for Request-Reply pattern,
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.rpc_path)
            yield socket
        finally:
            # linger == 0 means discard unsent messages
            # when the socket is closed. This is necessary
            # because otherwise self.context.destroy() will
            # wait for 30 seconds until unsent messages are
            # received, which is impossible if the server
            # crashed. In the absence of a server crash we
            # always expect a response before closing the
            # socket anyway.
            # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
            socket.close(linger=0)

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back.

        Raises:
            ValueError: if the reply is not of ``expected_type`` (a
                ``None`` reply is tolerated only for ``LoRAConfig``,
                which is legitimately absent when LoRA is disabled).
        """

        with self.socket() as socket:

            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

            if not isinstance(data, expected_type):
                # LoRAConfig can be None.
                if expected_type == LoRAConfig and data is None:
                    pass
                else:
                    raise ValueError(error_message)

            return data

    async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
                                        error_message: str):
        """Send one-way RPC request to trigger an action.

        Raises:
            ValueError: if the server does not acknowledge with the
                success sentinel string.
        """
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
            response = cloudpickle.loads(await socket.recv())

            if not isinstance(response, str) or response != \
                    APHRODITE_RPC_SUCCESS_STR:
                raise ValueError(error_message)

            return response

    async def get_tokenizer(self, lora_request: LoRARequest):
        """Return the (possibly LoRA-specific) tokenizer."""
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the DecodingConfig cached by setup()."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig cached by setup()."""
        return self.model_config

    async def wait_for_server(self):
        """Wait for the RPCServer to start up."""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.")

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self) -> Optional[LoRAConfig]:
        """Get LoRAConfig from the RPCServer.

        Returns ``None`` when LoRA is disabled on the server side (the
        data helper explicitly allows a ``None`` reply for LoRAConfig).
        """

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""
        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")

    @property
    def is_running(self) -> bool:
        # True until a streamed-back exception coincides with a failed
        # health check (see generate()).
        return not self._errored

    @property
    def is_stopped(self) -> bool:
        return self._errored

    @property
    def errored(self) -> bool:
        return self._errored

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses.

        Yields one ``RequestOutput`` per frame received from the server
        until a frame reports ``finished``. If the generator is abandoned
        before completion, the request is aborted server-side.
        """

        finished = False
        try:
            with self.socket() as socket:

                # Send RPCGenerateRequest to the RPCServer.
                await socket.send_multipart([
                    cloudpickle.dumps(
                        RPCGenerateRequest(
                            inputs=inputs,
                            sampling_params=sampling_params,
                            request_id=request_id,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request))
                ])

                # Stream back the results from the RPC Server.
                while not finished:
                    message = await socket.recv()
                    request_output = cloudpickle.loads(message)

                    if isinstance(request_output, Exception):
                        # On exception, check if the server is still healthy.
                        # Use this to set the sync `is_running` and `errored`
                        # properties.
                        try:
                            await self.check_health()
                        except Exception:
                            self._errored = True
                        # NB: do before raising here so that the flag is set
                        # by the time the caller receives this exception
                        raise request_output

                    finished = request_output.finished
                    yield request_output
        finally:
            # Cancel the server-side request if the stream was not
            # consumed to completion (e.g. client disconnected).
            if not finished:
                await self.abort(request_id)

    async def check_health(self) -> None:
        """Raise if unhealthy"""

        with self.socket() as socket:

            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH))

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

            if isinstance(health_message, Exception):
                raise health_message

            if health_message != APHRODITE_RPC_HEALTHY_STR:
                raise ValueError("Expected healthy response from backend but got "
                                 f"{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")