- from contextlib import contextmanager
- from typing import Any, AsyncGenerator, Optional
- import cloudpickle
- import zmq
- import zmq.asyncio
- from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig)
- from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
- from aphrodite.common.sampling_params import SamplingParams
- from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
- APHRODITE_RPC_SUCCESS_STR,
- RPC_REQUEST_TYPE, RPCAbortRequest,
- RPCGenerateRequest,
- RPCUtilityRequest)
- from aphrodite.inputs import PromptInputs
- from aphrodite.lora.request import LoRARequest
- from aphrodite.prompt_adapter.request import PromptAdapterRequest
- from aphrodite.transformers_utils.tokenizer_group import (
- init_tokenizer_from_configs)
# Time (in milliseconds) to wait before checking if the server process is
# alive.
# NOTE(review): not referenced anywhere in this module — presumably consumed
# by the server-startup path elsewhere; confirm before removing.
SERVER_START_TIMEOUT_MS = 1000
class AsyncEngineRPCClient:
    """Asyncio ZeroMQ client for a remote AsyncEngine RPC server.

    Each request opens a short-lived DEALER socket connected to
    ``rpc_path``; requests and responses are serialized with cloudpickle.
    Call :meth:`setup` once (after the server is up) before issuing
    engine requests, and :meth:`close` when finished.
    """

    def __init__(self, rpc_path: str):
        self.context = zmq.asyncio.Context()
        self.rpc_path = rpc_path
        # Initialize eagerly so the health properties (`is_running`,
        # `is_stopped`, `errored`) are safe to read even if `setup()` has
        # not run yet or failed partway through; previously this attribute
        # only existed after a successful `setup()`, so early property
        # access raised AttributeError.
        self._errored = False

    async def setup(self):
        """Setup the client before it starts sending server requests."""
        # Wait until server is ready.
        await self.wait_for_server()
        self._errored = False

        # Get the configs.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Destroy the ZeroMQ Context."""
        self.context.destroy()

    @contextmanager
    def socket(self):
        # Ensure client sockets are always closed after use

        # Connect to RPC socket for Request-Reply pattern,
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.rpc_path)
            yield socket
        finally:
            # linger == 0 means discard unsent messages
            # when the socket is closed. This is necessary
            # because otherwise self.context.destroy() will
            # wait for 30 seconds until unsent messages are
            # received, which is impossible if the server
            # crashed. In the absence of a server crash we
            # always expect a response before closing the
            # socket anyway.
            # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
            socket.close(linger=0)

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back.

        Raises:
            ValueError: if the reply is not of ``expected_type``
                (``None`` is tolerated only for ``LoRAConfig``).
        """
        with self.socket() as socket:
            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

            if not isinstance(data, expected_type):
                # LoRAConfig can be None.
                if expected_type == LoRAConfig and data is None:
                    pass
                else:
                    raise ValueError(error_message)

            return data

    async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
                                        error_message: str):
        """Send one-way RPC request to trigger an action.

        Raises:
            ValueError: if the server does not acknowledge with the
                success string.
        """
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
            response = cloudpickle.loads(await socket.recv())

            if not isinstance(response, str) or response != \
                    APHRODITE_RPC_SUCCESS_STR:
                raise ValueError(error_message)

            return response

    async def get_tokenizer(self, lora_request: LoRARequest):
        """Return the (possibly LoRA-specific) tokenizer."""
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the DecodingConfig cached during setup()."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the ModelConfig cached during setup()."""
        return self.model_config

    async def wait_for_server(self):
        """Wait for the RPCServer to start up."""
        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.")

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""
        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""
        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""
        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""
        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self) -> LoRAConfig:
        """Get LoRAConfig from the RPCServer"""
        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""
        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""
        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")

    @property
    def is_running(self) -> bool:
        return not self._errored

    @property
    def is_stopped(self) -> bool:
        return self._errored

    @property
    def errored(self) -> bool:
        return self._errored

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses."""

        finished = False
        try:
            with self.socket() as socket:

                # Send RPCGenerateRequest to the RPCServer.
                await socket.send_multipart([
                    cloudpickle.dumps(
                        RPCGenerateRequest(
                            inputs=inputs,
                            sampling_params=sampling_params,
                            request_id=request_id,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request))
                ])

                # Stream back the results from the RPC Server.
                while not finished:
                    message = await socket.recv()
                    request_output = cloudpickle.loads(message)

                    if isinstance(request_output, Exception):
                        # On exception, check if the server is still healthy.
                        # Use this to set the sync `is_running` and `errored`
                        # properties.
                        try:
                            await self.check_health()
                        except Exception:
                            self._errored = True
                        # NB: do before raising here so that the flag is set
                        # by the time the caller receives this exception
                        raise request_output

                    finished = request_output.finished
                    yield request_output
        finally:
            # Abort on the server side if the caller stopped consuming
            # (or an error interrupted the stream) before completion.
            if not finished:
                await self.abort(request_id)

    async def check_health(self) -> None:
        """Raise if unhealthy"""

        with self.socket() as socket:

            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
                              )

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

            if isinstance(health_message, Exception):
                raise health_message

            if health_message != APHRODITE_RPC_HEALTHY_STR:
                raise ValueError("Expected healthy response from backend but got "
                                 f"{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")