import asyncio
import signal
from typing import Any, Coroutine

import cloudpickle
import zmq
import zmq.asyncio
from loguru import logger
from typing_extensions import Never

from aphrodite import AsyncAphrodite, AsyncEngineArgs
from aphrodite.common.utils import in_windows
from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
                                            APHRODITE_RPC_SUCCESS_STR,
                                            RPCAbortRequest,
                                            RPCGenerateRequest,
                                            RPCUtilityRequest)

if in_windows():
    import winloop as uvloop
else:
    import uvloop


class AsyncEngineRPCServer:

    def __init__(self, async_engine_args: AsyncEngineArgs, rpc_path: str):
        # Initialize engine first.
        self.engine = AsyncAphrodite.from_engine_args(async_engine_args)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket for readiness state.
        self.socket = self.context.socket(zmq.constants.ROUTER)
        self.socket.bind(rpc_path)

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()
        self.engine.shutdown_background_loop()
        # Clear the engine reference so that it can be GC'ed.
        self.engine = None

    async def get_model_config(self, identity):
        """Send the ModelConfig."""
        model_config = await self.engine.get_model_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(model_config)])

    async def get_decoding_config(self, identity):
        """Send the DecodingConfig."""
        decoding_config = await self.engine.get_decoding_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(decoding_config)])

    async def get_lora_config(self, identity):
        """Send the LoRAConfig."""
        lora_config = await self.engine.get_lora_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(lora_config)])

    async def get_scheduler_config(self, identity):
        """Send the SchedulerConfig."""
        scheduler_config = await self.engine.get_scheduler_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(scheduler_config)])

    async def get_parallel_config(self, identity):
        """Send the ParallelConfig."""
        parallel_config = await self.engine.get_parallel_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])

    async def do_log_stats(self, identity):
        """Log stats and confirm success."""
        await self.engine.do_log_stats()

        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
        ])

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
        ])

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        try:
            # Abort the request in the llm engine.
            await self.engine.abort(request.request_id)
        except Exception:
            logger.warning(f"Failed to abort request {request.request_id}")
        finally:
            # Send confirmation to the client.
            await self.socket.send_multipart([
                identity,
                cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
            ])

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            # Stream each RequestOutput back to the client as it is produced.
            async for request_output in results_generator:
                await self.socket.send_multipart(
                    [identity, cloudpickle.dumps(request_output)])

        except Exception as e:
            # Notify the client of any failure by sending the exception itself.
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    async def check_health(self, identity):
        try:
            await self.engine.check_health()
            await self.socket.send_multipart(
                [identity, cloudpickle.dumps(APHRODITE_RPC_HEALTHY_STR)])
        except Exception as e:
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""

        request = cloudpickle.loads(message)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                return self.get_model_config(identity)
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                return self.get_parallel_config(identity)
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                return self.get_decoding_config(identity)
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                return self.get_scheduler_config(identity)
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                return self.get_lora_config(identity)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.CHECK_HEALTH:
                return self.check_health(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")

    async def run_server_loop(self):
        """Inner RPC server loop."""

        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart()

            # Process the request asynchronously.
            task = asyncio.create_task(
                self._make_handler_coro(identity, message))

            # Keep a strong reference to the task to avoid it disappearing
            # mid-execution, since running tasks can otherwise be GC'ed.
            # This is the common "fire-and-forget" pattern:
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    # Put the server task into the asyncio loop.
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    # Interruption handling.
    def signal_handler() -> None:
        # Kill the server on interrupt / terminate.
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("Aphrodite ZMQ RPC Server was interrupted.")
    finally:
        # Clean up all resources.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs, rpc_path: str):
    server = AsyncEngineRPCServer(async_engine_args, rpc_path)
    uvloop.run(run_server(server))
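

# Usage sketch (illustrative only, kept as comments so the module is unchanged):
# one way to launch this RPC server over an IPC transport. The module path used
# in the import, the `model` engine argument, and the rpc_path value below are
# assumptions for illustration, not part of this file.
#
#     from aphrodite import AsyncEngineArgs
#     from aphrodite.endpoints.openai.rpc.server import run_rpc_server
#
#     engine_args = AsyncEngineArgs(model="facebook/opt-125m")
#     run_rpc_server(engine_args, rpc_path="ipc:///tmp/aphrodite_rpc")
#
# Any valid ZMQ endpoint string accepted by socket.bind() (e.g. tcp:// or
# ipc://) should work for rpc_path, since it is passed through unmodified.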