123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489 |
- import asyncio
- import copy
- import pickle
- from contextlib import contextmanager, suppress
- from typing import Any, AsyncGenerator, Dict, Iterator, Optional, Union
- import cloudpickle
- import zmq
- import zmq.asyncio
- from loguru import logger
- from zmq import Frame # type: ignore[attr-defined]
- from zmq.asyncio import Socket
- from aphrodite import PoolingParams
- from aphrodite.common.config import DecodingConfig, EngineConfig, ModelConfig
- from aphrodite.common.envs import APHRODITE_RPC_TIMEOUT
- from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
- from aphrodite.common.sampling_params import SamplingParams
- from aphrodite.engine.args_tools import AsyncEngineArgs
- from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR,
- ENGINE_DEAD_ERROR, IPC_DATA_EXT,
- IPC_HEALTH_EXT, IPC_INPUT_EXT,
- IPC_OUTPUT_EXT, RPC_REQUEST_T,
- RPCAbortRequest, RPCError,
- RPCHealthRequest,
- RPCProcessRequest,
- RPCStartupRequest,
- RPCStartupResponse)
- from aphrodite.inputs import PromptInputs
- from aphrodite.lora.request import LoRARequest
- from aphrodite.prompt_adapter.request import PromptAdapterRequest
- from aphrodite.transformers_utils.tokenizer_group import (
- init_tokenizer_from_configs)
class MQClientClosedError(Exception):
    """Signal that the client was used after it had been closed.

    Closing the client tears down its ZMQ context, which normally happens
    while the server is shutting down. Late calls (e.g. ``abort``) can
    still arrive afterwards; touching a closed socket would surface as a
    ZMQError with a very long stack trace. Raising this dedicated
    exception instead lets callers recognise the situation and suppress
    it.
    """
class MQAphroditeEngineClient:
    """A client wrapper for MQAphroditeEngine that conforms to the
    EngineClient protocol.

    MQAphroditeEngine and MQAphroditeEngineClient are intended to run in
    separate processes communicating via zeromq ipc sockets.

    The entrypoint to MQAphroditeEngineClient is through the generate()
    method. On generate() MQAphroditeEngineClient does three things:
        - Creates an asyncio output queue
        - Sends a RPCProcessRequest to the MQAphroditeEngine via zmq
        - Pulls RequestOutputs from its queue and yields them

    MQAphroditeEngineClient runs two background loops:
        - output_loop: the output loop pulls List[RequestOutput]
            from the MQAphroditeEngine via zmq (each list is the output
            of one engine_step in the AphroditeEngine). It then parses
            the list and pushes individual request_outputs into
            the corresponding output_queue such that they can be
            consumed by the .generate() method.
        - health_loop: the health loop waits on the health socket and,
            whenever nothing arrives within the poll timeout, sends a
            health probe and waits for its acknowledgement.
    """

    # Whether the engine reported tracing as enabled. setup() overwrites
    # this with the value from the startup response; the class-level
    # default keeps is_tracing_enabled() usable before setup() has run
    # instead of failing with AttributeError.
    tracing_flag: bool = False

    def __init__(self, ipc_path: str, engine_config: EngineConfig):
        """Connect the sockets and start the output handler task.

        This schedules a task with ``asyncio.create_task``, so it has to
        be called while an event loop is running.

        Args:
            ipc_path: Base IPC path; the input/output/health/data
                extensions are appended to it to form each endpoint.
            engine_config: Source of the model, decoding, scheduler,
                parallel and LoRA configs used by the client.
        """
        self.context = zmq.asyncio.Context()
        self._errored_with: Optional[BaseException] = None

        # Get the configs.
        self.model_config = engine_config.model_config
        self.decoding_config = engine_config.decoding_config

        # Create the tokenizer group.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=engine_config.scheduler_config,
            parallel_config=engine_config.parallel_config,
            enable_lora=bool(engine_config.lora_config),
        )

        # Send RPCProcessRequest to the MQAphroditeEngine.
        self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
        self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")

        # Receive streams of RequestOutput from the MQAphroditeEngine.
        self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
        self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")

        # IPC path for ack of check_health requests.
        self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
        self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")

        # IPC path for the data socket.
        self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

        # Stream for each individual request.
        self.output_queues: Dict[str, asyncio.Queue] = {}
        self.output_loop = asyncio.create_task(self.run_output_handler_loop())

        # Loop to check health of the AphroditeEngine periodically.
        # Started after the MQAphroditeEngine is ready.
        self.health_loop: Optional[asyncio.Task] = None

    @staticmethod
    def is_unsupported_config(engine_args: AsyncEngineArgs):
        """Return True when ``engine_args`` cannot be served by this client."""
        # Pipeline parallel not yet supported
        return engine_args.pipeline_parallel_size > 1

    @contextmanager
    def get_data_socket(self) -> Iterator[Socket]:
        """Yield a DEALER socket connected to the data IPC path.

        The socket is always closed (with ``linger=0``) when the context
        exits, including when connecting or the body raises.
        """
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    async def run_check_health_loop(self, timeout: int):
        """Background loop that continually probes the RPCServer for health.

        The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
        the MQAphroditeEngine server is blocking on.

        The Server replies on the HEALTH_SOCKET (rather than on the
        OUTPUT_SOCKET such that the messages are not intermingled with
        output streaming).

        Any exception other than cancellation is recorded through
        ``_set_errored`` and ends the loop.

        Args:
            timeout: Poll timeout handed to the health socket; a probe is
                sent each time it elapses without a message.
        """
        try:
            while True:
                if await self.health_socket.poll(timeout=timeout) == 0:
                    # Nothing arrived within the timeout: do a health probe.
                    await self._send_one_way_rpc_request(
                        RPCHealthRequest(), self.input_socket)

                    # Wait for ack from the health socket.
                    await self._await_ack(error_message="Health check failed.",
                                          socket=self.health_socket)
                else:
                    # Server sent a health status message unprompted.
                    await self._check_success(
                        error_message="Health check failed.",
                        socket=self.health_socket)

                logger.debug("Health probe successful.")

        except asyncio.CancelledError:
            logger.debug(
                "Shutting down MQAphroditeEngineClient check health loop.")

        except Exception as e:
            self._set_errored(e)

    async def run_output_handler_loop(self):
        """Get RequestOutputs from Engine and stream to Request Queues.

        Each message is either a collection of request outputs, which are
        routed to the queue registered under their ``request_id``, or an
        error. Errors without a request id are delivered to every queue;
        errors with one go only to that request's queue. The first
        engine-level error is remembered in ``_errored_with``.

        The loop exits once the client is in the errored state and a poll
        times out, after pushing a dead-engine error to every queue.
        """
        try:
            while True:
                # Poll, checking for ENGINE_DEAD
                while await self.output_socket.poll(
                        timeout=APHRODITE_RPC_TIMEOUT) == 0:
                    logger.debug("Waiting for output from MQAphroditeEngine.")

                    # If errored, alert all running requests.
                    if self.errored:
                        for queue_j in tuple(self.output_queues.values()):
                            queue_j.put_nowait(
                                ENGINE_DEAD_ERROR(self._errored_with))
                        return

                message: Frame = await self.output_socket.recv(copy=False)
                # NOTE(review): unpickling is only safe if the peer on this
                # IPC socket is the trusted engine process -- confirm the
                # socket cannot be reached by untrusted writers.
                request_outputs = pickle.loads(message.buffer)

                is_error = isinstance(request_outputs,
                                      (BaseException, RPCError))
                if is_error:
                    if isinstance(request_outputs, RPCError):
                        rpc_error: RPCError = request_outputs
                        request_id = rpc_error.request_id
                        exception = rpc_error.exception
                        is_engine_errored = rpc_error.is_engine_errored
                    else:
                        # The engine should always return an RPCError to
                        # the output_socket when an issue arises.
                        # If we are here, we are in a bad state and
                        # should shut down the server.
                        error: BaseException = request_outputs
                        logger.error(
                            f"Received Exception {error} rather than RPCError "
                            "from MPAphroditeEngine. This should never happen.")
                        request_id = None
                        exception = error
                        is_engine_errored = True

                    # Set to error state only on engine critical error
                    # (and record only the first one). Compare against None
                    # explicitly, matching _set_errored/errored, rather than
                    # relying on the truthiness of an exception object.
                    if is_engine_errored and self._errored_with is None:
                        self._errored_with = exception

                    if request_id is None:
                        # Not tied to one request: tell everybody.
                        for queue_i in tuple(self.output_queues.values()):
                            queue_i.put_nowait(exception)
                    else:
                        queue = self.output_queues.get(request_id)
                        if queue is not None:
                            queue.put_nowait(exception)
                else:
                    # Put each output into the appropriate stream.
                    for request_output in request_outputs:
                        queue = self.output_queues.get(
                            request_output.request_id)

                        # Outputs for unknown/finished requests are dropped.
                        if queue is not None:
                            queue.put_nowait(request_output)

        except asyncio.CancelledError:
            logger.debug(
                "Shutting down MQAphroditeEngineClient output handler.")

    async def setup(self):
        """Setup the client before it starts sending server requests.

        Waits for the server's startup response on a temporary data
        socket, stores its tracing flag, and starts the health loop.
        """
        with self.get_data_socket() as socket:
            # Wait until server is ready.
            response = await self._wait_for_server_rpc(socket)

            self.tracing_flag = response.tracing_enabled

            # Start health_loop.
            self.health_loop = asyncio.create_task(
                self.run_check_health_loop(timeout=APHRODITE_RPC_TIMEOUT))

    def close(self):
        """Destroy the ZeroMQ Context and cancel the background tasks."""
        # Close all sockets and terminate the context.
        self.context.destroy(linger=0)

        # Cancel background tasks.
        if self.health_loop is not None:
            self.health_loop.cancel()
        self.output_loop.cancel()

    def _set_errored(self, e: BaseException):
        """Log ``e`` and remember it if no earlier error was recorded."""
        logger.exception(repr(e))
        if self._errored_with is None:
            self._errored_with = e

    @staticmethod
    async def _send_get_data_rpc_request(request: RPCStartupRequest,
                                         expected_type: Any,
                                         error_message: str,
                                         socket: Socket) -> Any:
        """Send an RPC request that is expecting data back.

        Args:
            request: The request object to pickle and send.
            expected_type: Type (or tuple of types) the reply must be an
                instance of.
            error_message: Message for the ValueError raised when the
                reply has an unexpected type.
            socket: Socket used for both the request and the reply.

        Returns:
            The unpickled reply.

        Raises:
            TimeoutError: No reply arrived within APHRODITE_RPC_TIMEOUT.
            ValueError: The reply is not an ``expected_type`` instance.
            BaseException: The reply itself was an exception; it is
                re-raised as is.
        """
        # Ping RPCServer with a request.
        await socket.send_multipart((pickle.dumps(request), ), copy=False)

        # Make sure the server responds in time.
        if await socket.poll(timeout=APHRODITE_RPC_TIMEOUT) == 0:
            raise TimeoutError("RPCServer didn't reply within "
                               f"{APHRODITE_RPC_TIMEOUT} ms")

        # Await the data from the Server.
        frame = await socket.recv(copy=False)
        data = pickle.loads(frame.buffer)

        if isinstance(data, BaseException):
            raise data
        elif not isinstance(data, expected_type):
            raise ValueError(error_message)

        return data

    @staticmethod
    async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
                                        socket: Socket):
        """Send one-way RPC request to trigger an action.

        Raises:
            MQClientClosedError: ``socket`` has already been closed.
        """
        if socket.closed:
            raise MQClientClosedError()

        await socket.send_multipart((pickle.dumps(request), ))

    async def _await_ack(self, error_message: str, socket: Socket):
        """Await acknowledgement that a request succeeded.

        Raises:
            MQClientClosedError: ``socket`` has already been closed.
            TimeoutError: No reply arrived within APHRODITE_RPC_TIMEOUT.
            ValueError: The reply was not the success marker.
        """
        if socket.closed:
            raise MQClientClosedError()

        if await socket.poll(timeout=APHRODITE_RPC_TIMEOUT) == 0:
            raise TimeoutError("MQAphroditeEngine didn't reply within "
                               f"{APHRODITE_RPC_TIMEOUT}ms")

        await self._check_success(error_message, socket)

    @staticmethod
    async def _check_success(error_message: str, socket: Socket):
        """Confirm that socket has a APHRODITE_RPC_SUCCESS_STR message.

        Raises:
            MQClientClosedError: ``socket`` has already been closed.
            ValueError: The reply is not APHRODITE_RPC_SUCCESS_STR.
            BaseException: The reply itself was an exception; it is
                re-raised as is.
        """
        if socket.closed:
            raise MQClientClosedError()

        frame = await socket.recv(copy=False)
        response = pickle.loads(frame.buffer)

        # Raise error if unsuccessful
        if isinstance(response, BaseException):
            raise response
        elif (not isinstance(response, str)
              or response != APHRODITE_RPC_SUCCESS_STR):
            raise ValueError(error_message)

    async def get_tokenizer(self, lora_request: LoRARequest):
        """Return the tokenizer the tokenizer group gives for the request."""
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        """Return the decoding config captured at construction time."""
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        """Return the model config captured at construction time."""
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        """Return the tracing flag (False until setup() has stored one)."""
        return self.tracing_flag

    async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
        """Wait for the RPCServer to start up."""
        return await self._send_get_data_rpc_request(
            request=RPCStartupRequest.IS_SERVER_READY,
            expected_type=RPCStartupResponse,
            error_message="Unable to start RPC Server",
            socket=socket)

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server.

        Silently does nothing when the input socket is already closed.
        """
        with suppress(MQClientClosedError):
            await self._send_one_way_rpc_request(
                request=RPCAbortRequest(request_id), socket=self.input_socket)

    async def do_log_stats(self):
        """Ignore do_log_stats (handled on MQAphroditeEngine polling)"""
        pass

    async def check_health(self):
        """Raise the recorded engine error, if there is one.

        The check health loop probes the Engine's health in the
        background and sets ``_errored_with`` if the engine is unhealthy;
        this method only reports that recorded state.
        """
        if self._errored_with is not None:
            raise self._errored_with

    @property
    def is_running(self) -> bool:
        """True while no error has been recorded."""
        return not self.errored

    @property
    def is_stopped(self) -> bool:
        """True once an error has been recorded."""
        return self.errored

    @property
    def errored(self) -> bool:
        """Whether an error has been recorded in ``_errored_with``."""
        return self._errored_with is not None

    @property
    def dead_error(self) -> BaseException:
        """A dead-engine error built from the recorded error."""
        return ENGINE_DEAD_ERROR(self._errored_with)

    def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generate outputs for a request.

        Returns an async generator; the request is only sent to the
        engine once the caller starts iterating. Iterating adds the
        request into the waiting queue of the AphroditeEngine and streams
        the outputs from the AphroditeEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~aphrodite.inputs.PromptInputs`
                for more details about the format of each input.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use
                                            for generation, if any.
        """
        return self._process_request(inputs, sampling_params, request_id,
                                     lora_request, prompt_adapter_request)

    def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model.

        Returns an async generator; the request is only sent to the
        engine once the caller starts iterating. Iterating adds the
        request into the waiting queue of the AphroditeEngine and streams
        the outputs from the AphroditeEngine to the caller.

        Args:
            inputs: The inputs to the LLM. See
                :class:`~aphrodite.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters of the request.
            request_id: The unique id of the request.
            lora_request: LoRA request to use for generation, if any.

        Yields:
            The output `EmbeddingRequestOutput` objects from the
            AphroditeEngine for the request.
        """
        return self._process_request(inputs, pooling_params, request_id,
                                     lora_request)

    async def _process_request(
        self,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
            EmbeddingRequestOutput, None]]:
        """Send an RPCProcessRequest to the RPCServer and stream responses.

        Yields each output placed on this request's queue until one is
        marked ``finished``. If iteration stops early while the client is
        not errored, an abort for ``request_id`` is sent.

        Raises:
            The dead-engine error if the client is already errored, or
            any exception delivered through the request's queue.
        """
        # If already dead, error out.
        if self._errored_with is not None:
            raise ENGINE_DEAD_ERROR(self._errored_with)

        # 1) Create output queue for this requests.
        queue: asyncio.Queue[Union[RequestOutput,
                                   BaseException]] = asyncio.Queue()
        self.output_queues[request_id] = queue

        try:
            # 2) Detach logits processors so that they can be pickled
            # separately (may require cloudpickle which is slower)
            if isinstance(params, SamplingParams) and params.logits_processors:
                # Defensive shallow copy
                params = copy.copy(params)
                logits_processors = params.logits_processors
                params.logits_processors = None
                lp_bytes = cloudpickle.dumps(logits_processors)
            else:
                lp_bytes = None

            request_bytes = pickle.dumps(
                RPCProcessRequest(
                    inputs=inputs,
                    params=params,
                    request_id=request_id,
                    lora_request=lora_request,
                    prompt_adapter_request=prompt_adapter_request))

            # 3) Send the RPCProcessRequest to the MQAphroditeEngine.
            parts = (request_bytes,
                     lp_bytes) if lp_bytes else (request_bytes, )
            await self.input_socket.send_multipart(parts, copy=False)

            # 4) Stream the RequestOutputs from the output queue. Note
            # that the output_loop pushes RequestOutput objects to this
            # queue after pulling them from the zmq socket.
            finished = False
            try:
                while not finished:
                    request_output = await queue.get()

                    if isinstance(request_output, BaseException):
                        raise request_output

                    finished = request_output.finished
                    yield request_output
            finally:
                # Request was canceled by the client.
                if not finished and not self.errored:
                    await self.abort(request_id)
        finally:
            # Tolerate the entry already being gone (e.g. the same
            # request_id was registered twice and the other stream cleaned
            # up first); a KeyError here would mask the real outcome.
            self.output_queues.pop(request_id, None)
|