client.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. import asyncio
  2. from contextlib import contextmanager, suppress
  3. from typing import Any, AsyncGenerator, Optional
  4. from uuid import uuid4
  5. import cloudpickle
  6. import zmq
  7. import zmq.asyncio
  8. from loguru import logger
  9. from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
  10. ParallelConfig, SchedulerConfig)
  11. from aphrodite.common.envs import APHRODITE_RPC_GET_DATA_TIMEOUT_MS
  12. from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
  13. from aphrodite.common.sampling_params import SamplingParams
  14. from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_SOCKET_LIMIT_CUTOFF,
  15. APHRODITE_RPC_SUCCESS_STR,
  16. APHRODITE_RPC_ZMQ_HWM,
  17. RPC_REQUEST_TYPE, RPCAbortRequest,
  18. RPCGenerateRequest,
  19. RPCUtilityRequest)
  20. from aphrodite.inputs import PromptInputs
  21. from aphrodite.lora.request import LoRARequest
  22. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  23. from aphrodite.transformers_utils.tokenizer_group import (
  24. init_tokenizer_from_configs)
  # Address of the in-process proxy endpoint. uuid4() is evaluated once, when
  # this module is imported, so the address stays fixed afterwards.
  INPROC_PROXY_PATH = "inproc://" + str(uuid4())
  class RPCClientClosedError(Exception):
      """Signal that the RPC client was used after it had been closed.

      Closing the client closes its ZMQ context, which normally happens on
      server shutdown. Methods such as abort and do_log_stats may still be
      called afterwards and would then try to open a socket, producing a
      ZMQError with a very large stack trace.

      This dedicated exception is raised instead so that callers can
      suppress that situation.
      """
  35. class AsyncEngineRPCClient:
  36. """
  37. RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
  38. The overall design mirrors the Asynchronous Client Server Pattern
  39. https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
  40. On startup, the RPCClient:
  41. - makes DEALER socket (to_rpc_server) that connects to the RPCServer
  42. via ipc, which uses unix sockets under the hood
  43. (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
  44. - makes ROUTER socket (from_api_server) that binds to a random
  45. inproc address, which uses memory under the hood
  46. (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
  47. - runs a proxy in a background asyncio task between
  48. from_api_server (ROUTER, inproc) and to_rpc_server (DEALER, ipc)
  49. Each request handled by the asyncio api_server calls generate():
  50. - make a DEALER socket that connects to from_api_server via inproc
  51. - send a RPCGenerateRequest to the inproc socket
  52. - background proxy forwards the request from inproc -> ipc
  53. - RPCServer responds to the request one token at a time over ipc
  54. - background proxy forwards the response from ipc -> inproc
  55. The connection looks like this:
  56. DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
  57. Message routing is performed via identities that are managed by the
  58. ROUTER socket. ROUTER sockets track every connection it has and
  59. tells the caller about these. The way it tells the caller is to stick
  60. the connection identity in front of each message received. When we
  61. send the message via a ROUTER, we first send an identity frame.
  62. See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
  63. for more details on connection identities.
  64. This proxy design enables us to use a single unix socket, which
  65. improves performance by avoiding syscalls (~5%) and avoids resource limits
  66. such as ulimit, which defaults to 1024 on ubuntu.
  67. Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
  68. which is required to avoid dropping messages under high load.
  69. This is generally not advisable. However, since we are in control
  70. of both sides of the connection + failure on either side is
  71. catastrophic to the overall system health and memory profiling
  72. suggests limited memory overhead relative to asyncio, we will
  73. proceed for now.
  74. See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
  75. for more details on high water marks.
  76. """
  def __init__(self, rpc_path: str):
      """Create the ZMQ context, both proxy sockets and the proxy task.

      Because the proxy is started with asyncio.create_task, the client
      must be constructed while an event loop is running.

      Args:
          rpc_path: Address the `to_rpc_server` DEALER socket binds to.

      Raises:
          ValueError: If the context's SOCKET_LIMIT is below
              APHRODITE_RPC_SOCKET_LIMIT_CUTOFF.
      """
      ctx = self.context = zmq.asyncio.Context()
      self._data_timeout = APHRODITE_RPC_GET_DATA_TIMEOUT_MS
      self._errored = False

      # ZMQ_SOCKET_LIMIT is the largest number of sockets the context can
      # be configured to open (typically 65536).
      # See http://api.zeromq.org/4-2:zmq-ctx-get
      socket_limit = ctx.get(zmq.constants.SOCKET_LIMIT)
      if socket_limit < APHRODITE_RPC_SOCKET_LIMIT_CUTOFF:
          raise ValueError(
              f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
              "the number of concurrent requests Aphrodite can process. "
              "Launch Aphrodite with --disable-frontend-multiprocessing and "
              "open a GitHub issue so we can investigate.")

      # Only one ipc connection uses a unix socket, so raising MAX_SOCKETS
      # all the way to SOCKET_LIMIT will not run into ulimit issues.
      ctx.set(zmq.constants.MAX_SOCKETS, socket_limit)

      # DEALER side of the proxy: the ipc link (unix sockets) to the
      # RPC server.
      ipc_dealer = self.to_rpc_server = ctx.socket(zmq.constants.DEALER)
      ipc_dealer.set_hwm(APHRODITE_RPC_ZMQ_HWM)
      ipc_dealer.bind(rpc_path)

      # ROUTER side of the proxy: the in-process (memory-based) endpoint
      # that the per-request sockets connect to.
      inproc_router = self.from_api_server = ctx.socket(
          zmq.constants.ROUTER)
      inproc_router.set_hwm(APHRODITE_RPC_ZMQ_HWM)
      inproc_router.bind(INPROC_PROXY_PATH)

      # Background asyncio task shuttling messages between the two sockets.
      self.proxy_task = asyncio.create_task(
          self.run_proxy(inproc_router, ipc_dealer))

      # One inproc socket is opened per request, which puts a hard cap on
      # the number of requests Aphrodite can run with frontend
      # multiprocessing. This value is used by uvicorn
      # (--limit-concurrency) to return 503 when the server is overloaded.
      # Each request needs 2 sockets: 1 for generate() and 1 for abort(),
      # do_log_stats() or check_health(); 2 more are subtracted from the
      # total.
      self.limit_concurrency = socket_limit // 2 - 2
  async def run_proxy(self, socket_from, socket_to):
      """Forward two-frame messages between two sockets in both directions.

      Both sockets are polled for readability. Whenever one is ready, a
      single (identity, message) pair is read from it and written to the
      other. The loop has no exit condition of its own.

      Args:
          socket_from: First endpoint; serviced first in every iteration.
          socket_to: Second endpoint.
      """
      poller = zmq.asyncio.Poller()
      for endpoint in (socket_from, socket_to):
          poller.register(endpoint, zmq.constants.POLLIN)

      # (source, sink) pairs, checked in this order on every wake-up.
      routes = ((socket_from, socket_to), (socket_to, socket_from))

      while True:
          ready = dict(await poller.poll())
          for source, sink in routes:
              if source in ready:
                  identity, msg = await source.recv_multipart()
                  await sink.send_multipart([identity, msg])
  async def setup(self):
      """Prepare the client: wait for the server, then fetch its configs.

      Populates `model_config`, `decoding_config` and `tokenizer`.
      """
      # Block until the RPC server reports that it is ready.
      await self._wait_for_server_rpc()

      # Configs that are kept on the client.
      self.model_config = await self._get_model_config_rpc()
      self.decoding_config = await self._get_decoding_config_rpc()

      # Configs that are only needed to build the tokenizer group.
      # TODO: refactor OAI server to avoid needing this info.
      scheduler_config = await self._get_scheduler_config_rpc()
      parallel_config = await self._get_parallel_config_rpc()
      lora_config = await self._get_lora_config_rpc()

      self.tokenizer = init_tokenizer_from_configs(
          model_config=self.model_config,
          scheduler_config=scheduler_config,
          parallel_config=parallel_config,
          enable_lora=bool(lora_config),
      )
  def close(self):
      """Stop the proxy task and destroy the ZeroMQ context.

      The background proxy task created in __init__ is cancelled first;
      otherwise it would stay scheduled and keep polling the two proxy
      sockets after they have been closed below.
      """
      # Request cancellation of the background proxy before tearing down
      # the sockets it is polling.
      self.proxy_task.cancel()

      # Close the proxy sockets, then destroy the context itself.
      self.from_api_server.close()
      self.to_rpc_server.close()
      self.context.destroy()
  @contextmanager
  def to_proxy_socket(self):
      """Yield a fresh DEALER socket connected to the in-process proxy.

      The socket is closed with linger=0 when the `with` block exits.

      Raises:
          RPCClientClosedError: If the ZMQ context has already been closed.
              This can happen when a server shutdown is triggered while
              some coroutines are still running requests.
      """
      # No race with close(): nothing yields to the event loop between this
      # check and the socket creation below.
      if self.context.closed:
          raise RPCClientClosedError("The ZMQ client has already shut down")

      # DEALER allows asynchronous communication, which streaming needs.
      dealer = self.context.socket(zmq.constants.DEALER)
      dealer.set_hwm(APHRODITE_RPC_ZMQ_HWM)
      try:
          dealer.connect(INPROC_PROXY_PATH)
          yield dealer
      finally:
          dealer.close(linger=0)
  async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                       expected_type: Any,
                                       error_message: str) -> Any:
      """Send an RPC request that expects data back and return that data.

      Args:
          request: Utility request to send to the RPC server.
          expected_type: Type the unpickled reply must be an instance of.
              For LoRAConfig a reply of None is accepted as well.
          error_message: Message that is logged when the server replies
              with an exception, and used for the ValueError raised on a
              reply of the wrong type.

      Returns:
          The unpickled reply.

      Raises:
          RPCClientClosedError: If the client has already been closed.
          TimeoutError: If no reply arrives within `self._data_timeout` ms.
          Exception: Any exception instance sent back by the server is
              re-raised as is.
          ValueError: If the reply is not of the expected type.
      """
      with self.to_proxy_socket() as socket:
          # Ping the RPC server with the request.
          await socket.send_multipart([cloudpickle.dumps(request)])

          # Make sure the server responds in time.
          if await socket.poll(timeout=self._data_timeout) == 0:
              raise TimeoutError("Server didn't reply within "
                                 f"{self._data_timeout} ms")

          # Await the data from the server.
          data = cloudpickle.loads(await socket.recv())

      if isinstance(data, Exception):
          # Re-raise exceptions returned by the server. The error message is
          # logged first, as _send_one_way_rpc_request does.
          logger.error(error_message)
          raise data

      if isinstance(data, expected_type):
          return data

      # LoRAConfig can be None.
      if expected_type == LoRAConfig and data is None:
          return data

      raise ValueError(error_message)
  async def _send_one_way_rpc_request(
          self,
          request: RPC_REQUEST_TYPE,
          error_message: str,
          socket: Optional[zmq.asyncio.Socket] = None):
      """Send a one-way RPC request that triggers an action on the server.

      Args:
          request: Request to send.
          error_message: Message that is logged when the server replies
              with an exception, and used for the ValueError raised on any
              other unexpected reply.
          socket: Existing socket to reuse. When None, a new proxy socket
              is opened for this call only.

      Raises:
          TimeoutError: If no reply arrives within `self._data_timeout` ms.
          Exception: Any exception instance sent back by the server.
          ValueError: If the reply is neither the success string nor an
              exception.
      """

      async def round_trip(sock: zmq.asyncio.Socket):
          # Send the pickled request and wait (bounded) for a single reply.
          await sock.send_multipart([cloudpickle.dumps(request)])
          if await sock.poll(timeout=self._data_timeout) == 0:
              raise TimeoutError("Server didn't reply within "
                                 f"{self._data_timeout} ms")
          return cloudpickle.loads(await sock.recv())

      if socket is not None:
          # Use the caller's existing socket connection.
          reply = await round_trip(socket)
      else:
          # Make a new socket connection just for this call.
          with self.to_proxy_socket() as fresh_socket:
              reply = await round_trip(fresh_socket)

      if isinstance(reply, str) and reply == APHRODITE_RPC_SUCCESS_STR:
          return

      if isinstance(reply, Exception):
          logger.error(error_message)
          raise reply
      raise ValueError(error_message)
  async def get_tokenizer(self, lora_request: LoRARequest):
      """Return what the tokenizer group yields for `lora_request`."""
      tokenizer_group = self.tokenizer
      return await tokenizer_group.get_lora_tokenizer_async(lora_request)
  async def get_decoding_config(self) -> DecodingConfig:
      """Return the DecodingConfig stored on the client by setup()."""
      decoding_config = self.decoding_config
      return decoding_config
  async def get_model_config(self) -> ModelConfig:
      """Return the ModelConfig stored on the client by setup()."""
      model_config = self.model_config
      return model_config
  async def _wait_for_server_rpc(self):
      """Send IS_SERVER_READY and wait for the RPC server to acknowledge."""
      ready_probe = RPCUtilityRequest.IS_SERVER_READY
      await self._send_one_way_rpc_request(
          ready_probe, "Unable to start RPC Server")
  async def _get_model_config_rpc(self) -> ModelConfig:
      """Fetch the ModelConfig object from the RPC server."""
      model_config = await self._send_get_data_rpc_request(
          request=RPCUtilityRequest.GET_MODEL_CONFIG,
          expected_type=ModelConfig,
          error_message="Could not get ModelConfig from RPC Server",
      )
      return model_config
  async def _get_decoding_config_rpc(self) -> DecodingConfig:
      """Fetch the DecodingConfig object from the RPC server."""
      decoding_config = await self._send_get_data_rpc_request(
          request=RPCUtilityRequest.GET_DECODING_CONFIG,
          expected_type=DecodingConfig,
          error_message="Could not get DecodingConfig from RPC Server",
      )
      return decoding_config
  async def _get_parallel_config_rpc(self) -> ParallelConfig:
      """Fetch the ParallelConfig object from the RPC server."""
      parallel_config = await self._send_get_data_rpc_request(
          request=RPCUtilityRequest.GET_PARALLEL_CONFIG,
          expected_type=ParallelConfig,
          error_message="Could not get ParallelConfig from RPC Server",
      )
      return parallel_config
  async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
      """Fetch the SchedulerConfig object from the RPC server."""
      scheduler_config = await self._send_get_data_rpc_request(
          request=RPCUtilityRequest.GET_SCHEDULER_CONFIG,
          expected_type=SchedulerConfig,
          error_message="Could not get SchedulerConfig from RPC Server",
      )
      return scheduler_config
  async def _get_lora_config_rpc(self) -> Optional[LoRAConfig]:
      """Fetch the LoRAConfig from the RPC server.

      Returns:
          The server's LoRAConfig, or None: _send_get_data_rpc_request
          explicitly accepts a None reply when the expected type is
          LoRAConfig.
      """
      return await self._send_get_data_rpc_request(
          RPCUtilityRequest.GET_LORA_CONFIG,
          expected_type=LoRAConfig,
          error_message="Could not get LoRAConfig from RPC Server")
  async def abort(self, request_id: str):
      """Send an abort request for `request_id` to the RPC server.

      RPCClientClosedError and TimeoutError are deliberately ignored.
      """
      # Why timeouts are ignored too: when the server is busy and a very
      # large volume of aborts arrives at once (seen when aborting 20k
      # requests while another 2k were processing), it is likely that the
      # server cannot ack all of them in time. Many of them timed out, yet
      # the server was seen to abort every request successfully. We
      # therefore assume the server has received, or will receive, the
      # abort, which avoids a massive wall of `TimeoutError` stack traces.
      abort_request = RPCAbortRequest(request_id)
      try:
          await self._send_one_way_rpc_request(
              request=abort_request,
              error_message=f"RPCAbortRequest {request_id} failed")
      except (RPCClientClosedError, TimeoutError):
          return
  async def do_log_stats(self):
      """Send a DO_LOG_STATS request to the RPC server.

      A client that has already been closed is silently ignored.
      """
      try:
          await self._send_one_way_rpc_request(
              RPCUtilityRequest.DO_LOG_STATS,
              "RPCRequest DO_LOG_STATS failed.")
      except RPCClientClosedError:
          return
  282. @property
  283. def is_running(self) -> bool:
  284. return not self._errored
  285. @property
  286. def is_stopped(self) -> bool:
  287. return self._errored
  288. @property
  289. def errored(self) -> bool:
  290. return self._errored
  async def generate(
      self,
      inputs: PromptInputs,
      sampling_params: SamplingParams,
      request_id: str,
      lora_request: Optional[LoRARequest] = None,
      prompt_adapter_request: Optional[PromptAdapterRequest] = None
  ) -> AsyncGenerator[RequestOutput, None]:
      """Send an RPCGenerateRequest and stream the server's responses.

      Outputs are yielded until one reports `finished`. If the server sends
      back an exception, the server's health is checked (which may set the
      errored flag) and the exception is then raised. If the generator ends
      before a finished output was seen and the client is not errored, an
      abort for `request_id` is sent.
      """
      finished = False
      try:
          with self.to_proxy_socket() as stream:
              # Ship the generate request to the RPC server.
              rpc_request = RPCGenerateRequest(
                  inputs=inputs,
                  sampling_params=sampling_params,
                  request_id=request_id,
                  lora_request=lora_request,
                  prompt_adapter_request=prompt_adapter_request)
              await stream.send_multipart([cloudpickle.dumps(rpc_request)])

              # Stream the results back, one message per output.
              while not finished:
                  request_output = cloudpickle.loads(await stream.recv())

                  if not isinstance(request_output, Exception):
                      finished = request_output.finished
                      yield request_output
                      continue

                  # The server sent an exception. Probe its health first,
                  # so that the errored flag is already set by the time
                  # the caller receives the exception raised below.
                  if not self._errored:
                      try:
                          await self.check_health(socket=stream)
                      except Exception as e:
                          self._errored = True
                          logger.exception(repr(e))
                  raise request_output
      finally:
          # Not finished and not errored: the request was canceled by the
          # client, so tell the server to abort it.
          if not (finished or self._errored):
              await self.abort(request_id)
  async def check_health(self,
                         socket: Optional[zmq.asyncio.Socket] = None
                         ) -> None:
      """Ask the RPC server whether it is healthy; raise if it is not.

      Args:
          socket: Existing socket to reuse for the request. When None, the
              underlying helper opens a new one.
      """
      health_probe = RPCUtilityRequest.IS_SERVER_HEALTHY
      await self._send_one_way_rpc_request(
          health_probe,
          "Got Unhealthy response from RPC Server",
          socket,
      )
  async def encode(self, *args,
                   **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
      """Always raise: this client does not support embeddings.

      Raises:
          NotImplementedError: Unconditionally.
      """
      reason = "Embeddings not supported with multiprocessing backend"
      raise NotImplementedError(reason)