client.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. import asyncio
  2. import copy
  3. import pickle
  4. from contextlib import contextmanager, suppress
  5. from typing import Any, AsyncGenerator, Dict, Iterator, Optional, Union
  6. import cloudpickle
  7. import zmq
  8. import zmq.asyncio
  9. from loguru import logger
  10. from zmq import Frame # type: ignore[attr-defined]
  11. from zmq.asyncio import Socket
  12. from aphrodite import PoolingParams
  13. from aphrodite.common.config import DecodingConfig, EngineConfig, ModelConfig
  14. from aphrodite.common.envs import APHRODITE_RPC_TIMEOUT
  15. from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
  16. from aphrodite.common.sampling_params import SamplingParams
  17. from aphrodite.engine.args_tools import AsyncEngineArgs
  18. from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR,
  19. ENGINE_DEAD_ERROR, IPC_DATA_EXT,
  20. IPC_HEALTH_EXT, IPC_INPUT_EXT,
  21. IPC_OUTPUT_EXT, RPC_REQUEST_T,
  22. RPCAbortRequest, RPCError,
  23. RPCHealthRequest,
  24. RPCProcessRequest,
  25. RPCStartupRequest,
  26. RPCStartupResponse)
  27. from aphrodite.inputs import PromptInputs
  28. from aphrodite.lora.request import LoRARequest
  29. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  30. from aphrodite.transformers_utils.tokenizer_group import (
  31. init_tokenizer_from_configs)
  32. class MQClientClosedError(Exception):
  33. """Exception class raised when the client is used post-close.
  34. The client can be closed, which closes the ZMQ context. This normally
  35. happens on server shutdown. In some cases, methods like abort and
  36. do_log_stats will still be called and then try to open a socket, which
  37. causes a ZMQError and creates a huge stack trace.
  38. So, we throw this error such that we can suppress it.
  39. """
  40. class MQAphroditeEngineClient:
  41. """A client wrapper for MQAphroditeEngine that conforms to the
  42. EngineClient protocol.
  43. MQAphroditeEngine and MQAphroditeEngineClient are intended to run in
  44. separate processes communicating via zeromq ipc sockets.
  45. The entrypoint to MQAphroditeEngineClient is through the generate()
  46. method. On generate() MQAphroditeEngine does three things:
  47. - Creates an asyncio output queue
  48. - Sends a RPCGenerateRequest to the MQAphroditeEngine via zmq
  49. - Pulls RequestOutputs from its queue and yields them
  50. MQAphroditeEngine runs two background loops:
  51. - output_loop: the output loop pulls List[RequestOutput]
  52. from the MQAphroditeEngine via zmq (each list is the output
  53. of one engine_step in the AphroditeEngine). It then parses
  54. the list and pushes individual request_outputs into
  55. the corresponding output_queue such that they can be
  56. consumed by the .generate() method.
  57. - health_loop: the health loop queries the health socket
  58. every N seconds, confirming the engine is healthy
  59. """
  60. def __init__(self, ipc_path: str, engine_config: EngineConfig):
  61. self.context = zmq.asyncio.Context()
  62. self._errored_with: Optional[BaseException] = None
  63. # Get the configs.
  64. self.model_config = engine_config.model_config
  65. self.decoding_config = engine_config.decoding_config
  66. # Create the tokenizer group.
  67. self.tokenizer = init_tokenizer_from_configs(
  68. model_config=self.model_config,
  69. scheduler_config=engine_config.scheduler_config,
  70. parallel_config=engine_config.parallel_config,
  71. enable_lora=bool(engine_config.lora_config),
  72. )
  73. # Send RPCGenerateRequest to the MQAphroditeEngine.
  74. self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
  75. self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")
  76. # Receive streams of RequestOutput from the MQAphroditeEngine.
  77. self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
  78. self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
  79. # IPC path for ack of check_health requests.
  80. self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
  81. self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
  82. # IPC path for the data socket.
  83. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
  84. # Stream for each individual request.
  85. self.output_queues: Dict[str, asyncio.Queue] = {}
  86. self.output_loop = asyncio.create_task(self.run_output_handler_loop())
  87. # Loop to check health of the AphroditeEngine periodically.
  88. # Started after the MQAphroditeEngine is ready.
  89. self.health_loop: Optional[asyncio.Task] = None
  90. @staticmethod
  91. def is_unsupported_config(engine_args: AsyncEngineArgs):
  92. # Pipeline parallel not yet supported
  93. return engine_args.pipeline_parallel_size > 1
  94. @contextmanager
  95. def get_data_socket(self) -> Iterator[Socket]:
  96. socket = self.context.socket(zmq.constants.DEALER)
  97. try:
  98. socket.connect(self.data_ipc_path)
  99. yield socket
  100. finally:
  101. socket.close(linger=0)
  102. async def run_check_health_loop(self, timeout: int):
  103. """Background loop that continually probes the RPCServer for health.
  104. The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
  105. the MQAphroditeEngine server is blocking on.
  106. The Server replies on the HEALTH_SOCKET (rather than on the
  107. OUTPUT_SOCKET such that the messages are not intermingled with
  108. output streaming).
  109. """
  110. try:
  111. while True:
  112. if await self.health_socket.poll(timeout=timeout) == 0:
  113. # Wakeup every N seconds and do a health probe.
  114. await self._send_one_way_rpc_request(
  115. RPCHealthRequest(), self.input_socket)
  116. # Wait for ack from the health socket.
  117. await self._await_ack(error_message="Health check failed.",
  118. socket=self.health_socket)
  119. else:
  120. # Server sent a health status message unprompted.
  121. await self._check_success(
  122. error_message="Health check failed.",
  123. socket=self.health_socket)
  124. logger.debug("Health probe successful.")
  125. except asyncio.CancelledError:
  126. logger.debug(
  127. "Shutting down MQAphroditeEngineClient check health loop.")
  128. except Exception as e:
  129. self._set_errored(e)
  130. async def run_output_handler_loop(self):
  131. """Get RequestOutputs from Engine and stream to Request Queues"""
  132. try:
  133. while True:
  134. # Poll, checking for ENGINE_DEAD
  135. while await self.output_socket.poll(
  136. timeout=APHRODITE_RPC_TIMEOUT
  137. ) == 0:
  138. logger.debug("Waiting for output from MQAphroditeEngine.")
  139. # If errored, alert all running requests.
  140. if self.errored:
  141. for queue_j in tuple(self.output_queues.values()):
  142. queue_j.put_nowait(
  143. ENGINE_DEAD_ERROR(self._errored_with))
  144. return
  145. message: Frame = await self.output_socket.recv(copy=False)
  146. request_outputs = pickle.loads(message.buffer)
  147. is_error = isinstance(request_outputs,
  148. (BaseException, RPCError))
  149. if is_error:
  150. if isinstance(request_outputs, RPCError):
  151. rpc_error: RPCError = request_outputs
  152. request_id = rpc_error.request_id
  153. exception = rpc_error.exception
  154. is_engine_errored = rpc_error.is_engine_errored
  155. else:
  156. # MPAphroditeEngine should always return an RPCError to
  157. # the output_socket when an issue arises.
  158. # If we are here, we are in a bad state and
  159. # should shut down the server.
  160. error: BaseException = request_outputs
  161. logger.error(
  162. f"Received Exception {error} rather than RPCError "
  163. "from MPAphroditeEngine. This should never happen.")
  164. request_id = None
  165. exception = error
  166. is_engine_errored = True
  167. # Set to error state only on engine critical error
  168. # (and record only the first one)
  169. if is_engine_errored and not self._errored_with:
  170. self._errored_with = exception
  171. if request_id is None:
  172. for queue_i in tuple(self.output_queues.values()):
  173. queue_i.put_nowait(exception)
  174. else:
  175. queue = self.output_queues.get(request_id)
  176. if queue is not None:
  177. queue.put_nowait(exception)
  178. else:
  179. # Put each output into the appropriate steam.
  180. for request_output in request_outputs:
  181. queue = self.output_queues.get(
  182. request_output.request_id)
  183. if queue is not None:
  184. queue.put_nowait(request_output)
  185. except asyncio.CancelledError:
  186. logger.debug(
  187. "Shutting down MQAphroditeEngineClient output handler.")
  188. async def setup(self):
  189. """Setup the client before it starts sending server requests."""
  190. with self.get_data_socket() as socket:
  191. # Wait until server is ready.
  192. response = await self._wait_for_server_rpc(socket)
  193. self.tracing_flag = response.tracing_enabled
  194. # Start health_loop.
  195. self.health_loop = asyncio.create_task(
  196. self.run_check_health_loop(timeout=APHRODITE_RPC_TIMEOUT))
  197. def close(self):
  198. """Destroy the ZeroMQ Context."""
  199. # Close all sockets and terminate the context.
  200. self.context.destroy(linger=0)
  201. # Cancel background tasks.
  202. if self.health_loop is not None:
  203. self.health_loop.cancel()
  204. self.output_loop.cancel()
  205. def _set_errored(self, e: BaseException):
  206. logger.exception(repr(e))
  207. if self._errored_with is None:
  208. self._errored_with = e
  209. @staticmethod
  210. async def _send_get_data_rpc_request(request: RPCStartupRequest,
  211. expected_type: Any,
  212. error_message: str,
  213. socket: Socket) -> Any:
  214. """Send an RPC request that is expecting data back."""
  215. # Ping RPCServer with a request.
  216. await socket.send_multipart((pickle.dumps(request), ), copy=False)
  217. # Make sure the server responds in time.
  218. if await socket.poll(timeout=APHRODITE_RPC_TIMEOUT) == 0:
  219. raise TimeoutError("RPCServer didn't reply within "
  220. f"{APHRODITE_RPC_TIMEOUT} ms")
  221. # Await the data from the Server.
  222. frame = await socket.recv(copy=False)
  223. data = pickle.loads(frame.buffer)
  224. if isinstance(data, BaseException):
  225. raise data
  226. elif not isinstance(data, expected_type):
  227. raise ValueError(error_message)
  228. return data
  229. @staticmethod
  230. async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
  231. socket: Socket):
  232. """Send one-way RPC request to trigger an action."""
  233. if socket.closed:
  234. raise MQClientClosedError()
  235. await socket.send_multipart((pickle.dumps(request), ))
  236. async def _await_ack(self, error_message: str, socket: Socket):
  237. """Await acknowledgement that a request succeeded."""
  238. if socket.closed:
  239. raise MQClientClosedError()
  240. if await socket.poll(timeout=APHRODITE_RPC_TIMEOUT) == 0:
  241. raise TimeoutError("MQAphroditeEngine didn't reply within "
  242. f"{APHRODITE_RPC_TIMEOUT}ms")
  243. await self._check_success(error_message, socket)
  244. @staticmethod
  245. async def _check_success(error_message: str, socket: Socket):
  246. """Confirm that socket has a APHRODITE_RPC_SUCCESS_STR message"""
  247. if socket.closed:
  248. raise MQClientClosedError()
  249. frame = await socket.recv(copy=False)
  250. response = pickle.loads(frame.buffer)
  251. # Raise error if unsuccessful
  252. if isinstance(response, BaseException):
  253. raise response
  254. elif (not isinstance(response, str)
  255. or response != APHRODITE_RPC_SUCCESS_STR):
  256. raise ValueError(error_message)
  257. async def get_tokenizer(self, lora_request: LoRARequest):
  258. return await self.tokenizer.get_lora_tokenizer_async(lora_request)
  259. async def get_decoding_config(self) -> DecodingConfig:
  260. return self.decoding_config
  261. async def get_model_config(self) -> ModelConfig:
  262. return self.model_config
  263. async def is_tracing_enabled(self) -> bool:
  264. return self.tracing_flag
  265. async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
  266. """Wait for the RPCServer to start up."""
  267. return await self._send_get_data_rpc_request(
  268. request=RPCStartupRequest.IS_SERVER_READY,
  269. expected_type=RPCStartupResponse,
  270. error_message="Unable to start RPC Server",
  271. socket=socket)
  272. async def abort(self, request_id: str):
  273. """Send an ABORT_REQUEST signal to the RPC Server"""
  274. with suppress(MQClientClosedError):
  275. await self._send_one_way_rpc_request(
  276. request=RPCAbortRequest(request_id), socket=self.input_socket)
  277. async def do_log_stats(self):
  278. """Ignore do_log_stats (handled on MQAphroditeEngine polling)"""
  279. pass
  280. async def check_health(self):
  281. """
  282. The check health loop probes the health status of the
  283. Engine's health every N seconds and sets _errored_with
  284. if the engine is unhealthy.
  285. """
  286. if self._errored_with is not None:
  287. raise self._errored_with
  288. @property
  289. def is_running(self) -> bool:
  290. return not self.errored
  291. @property
  292. def is_stopped(self) -> bool:
  293. return self.errored
  294. @property
  295. def errored(self) -> bool:
  296. return self._errored_with is not None
  297. @property
  298. def dead_error(self) -> BaseException:
  299. return ENGINE_DEAD_ERROR(self._errored_with)
  300. def generate(
  301. self,
  302. inputs: PromptInputs,
  303. sampling_params: SamplingParams,
  304. request_id: str,
  305. lora_request: Optional[LoRARequest] = None,
  306. prompt_adapter_request: Optional[PromptAdapterRequest] = None
  307. ) -> AsyncGenerator[RequestOutput, None]:
  308. """Generate outputs for a request.
  309. Generate outputs for a request. This method is a coroutine. It adds the
  310. request into the waiting queue of the AphroditeEngine and streams the
  311. outputs from the AphroditeEngine to the caller.
  312. Args:
  313. inputs: The inputs to the LLM. See
  314. :class:`~aphrodite.inputs.PromptInputs`
  315. for more details about the format of each input.
  316. sampling_params: The sampling parameters of the request.
  317. request_id: The unique id of the request.
  318. lora_request: LoRA request to use for generation, if any.
  319. prompt_adapter_request: Prompt Adapter request to use
  320. for generation, if any.
  321. """
  322. return self._process_request(inputs, sampling_params, request_id,
  323. lora_request, prompt_adapter_request)
  324. def encode(
  325. self,
  326. inputs: PromptInputs,
  327. pooling_params: PoolingParams,
  328. request_id: str,
  329. lora_request: Optional[LoRARequest] = None,
  330. ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
  331. """Generate outputs for a request from an embedding model.
  332. Generate outputs for a request. This method is a coroutine. It adds the
  333. request into the waiting queue of the AphroditeEngine and streams the
  334. outputs from the AphroditeEngine to the caller.
  335. Args:
  336. inputs: The inputs to the LLM. See
  337. :class:`~aphrodite.inputs.PromptInputs`
  338. for more details about the format of each input.
  339. pooling_params: The pooling parameters of the request.
  340. request_id: The unique id of the request.
  341. lora_request: LoRA request to use for generation, if any.
  342. Yields:
  343. The output `EmbeddingRequestOutput` objects from the AphroditeEngine
  344. for the request.
  345. """
  346. return self._process_request(inputs, pooling_params, request_id,
  347. lora_request)
  348. async def _process_request(
  349. self,
  350. inputs: PromptInputs,
  351. params: Union[SamplingParams, PoolingParams],
  352. request_id: str,
  353. lora_request: Optional[LoRARequest] = None,
  354. prompt_adapter_request: Optional[PromptAdapterRequest] = None
  355. ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
  356. EmbeddingRequestOutput, None]]:
  357. """Send an RPCGenerateRequest to the RPCServer and stream responses."""
  358. # If already dead, error out.
  359. if self._errored_with is not None:
  360. raise ENGINE_DEAD_ERROR(self._errored_with)
  361. # 1) Create output queue for this requests.
  362. queue: asyncio.Queue[Union[RequestOutput,
  363. BaseException]] = asyncio.Queue()
  364. self.output_queues[request_id] = queue
  365. try:
  366. # 2) Detach logits processors so that they can be pickled
  367. # separately (may require cloudpickle which is slower)
  368. if isinstance(params, SamplingParams) and params.logits_processors:
  369. # Defensive shallow copy
  370. params = copy.copy(params)
  371. logits_processors = params.logits_processors
  372. params.logits_processors = None
  373. lp_bytes = cloudpickle.dumps(logits_processors)
  374. else:
  375. lp_bytes = None
  376. request_bytes = pickle.dumps(
  377. RPCProcessRequest(
  378. inputs=inputs,
  379. params=params,
  380. request_id=request_id,
  381. lora_request=lora_request,
  382. prompt_adapter_request=prompt_adapter_request))
  383. # 3) Send the RPCGenerateRequest to the MQAphroditeEngine.
  384. parts = (request_bytes,
  385. lp_bytes) if lp_bytes else (request_bytes, )
  386. await self.input_socket.send_multipart(parts, copy=False)
  387. # 4) Stream the RequestOutputs from the output queue. Note
  388. # that the output_loop pushes RequestOutput objects to this
  389. # queue after pulling them from the zmq socket.
  390. finished = False
  391. try:
  392. while not finished:
  393. request_output = await queue.get()
  394. if isinstance(request_output, BaseException):
  395. raise request_output
  396. finished = request_output.finished
  397. yield request_output
  398. finally:
  399. # Request was canceled by the client.
  400. if not finished and not self.errored:
  401. await self.abort(request_id)
  402. finally:
  403. self.output_queues.pop(request_id)