engine.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. import os
  2. import pickle
  3. import signal
  4. import sys
  5. from contextlib import contextmanager
  6. from typing import Iterator, List, Optional, Union
  7. import cloudpickle
  8. import zmq
  9. from loguru import logger
  10. from aphrodite import AphroditeEngine, AsyncEngineArgs, SamplingParams
  11. from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
  12. ParallelConfig, SchedulerConfig)
  13. from aphrodite.common.outputs import RequestOutput
  14. from aphrodite.engine.multiprocessing import (APHRODITE_RPC_SUCCESS_STR,
  15. ENGINE_DEAD_ERROR, IPC_DATA_EXT,
  16. IPC_HEALTH_EXT, IPC_INPUT_EXT,
  17. IPC_OUTPUT_EXT,
  18. REQUEST_OUTPUTS_T,
  19. RPCAbortRequest, RPCError,
  20. RPCHealthRequest,
  21. RPCProcessRequest,
  22. RPCShutdownRequest,
  23. RPCStartupRequest,
  24. RPCStartupResponse)
  # Union of the engine configuration classes imported from
  # aphrodite.common.config.
  CONFIG_TYPE = Union[
      ModelConfig, DecodingConfig, ParallelConfig, SchedulerConfig, LoRAConfig
  ]

  # Timeout (milliseconds) for each idle poll of the engine's input socket.
  POLLING_TIMEOUT_MS = 10_000

  # Pre-pickled single-frame multipart message sent on the health socket
  # when the engine reports itself healthy.
  HEALTHY_RESPONSE = (pickle.dumps(APHRODITE_RPC_SUCCESS_STR),)
  class MQAphroditeEngine:
      """A multiprocessing wrapper for :class:`AphroditeEngine`.

      This class wraps :class:`AphroditeEngine` so that it can be used in a
      concurrent manner. It runs a background loop and uses zeromq to
      receive new requests and stream outputs incrementally via ipc.

      When an RPCProcessRequest is received on the input_socket, it is added
      to the wrapped AphroditeEngine via ``add_request``.

      :meth:`run_engine_loop` checks the input_socket for new requests,
      adds them to the AphroditeEngine if there are any, calls the internal
      :meth:`AphroditeEngine.step`, and sends the RequestOutputs back over
      the output_socket.

      If use_async_sockets is set, the logic associated with reading new
      requests from the socket and sending data to the socket is passed
      as a callback to the engine, which calls the logic asynchronously
      such that the IPC can be overlapped with the GPU.

      Args:
          ipc_path: Base path for zeromq interprocess messaging.
          use_async_sockets: Whether to make send/recv async with GPU.
          log_requests: Whether to log the requests.
          *args: Arguments for :class:`AphroditeEngine`.
          **kwargs: Arguments for :class:`AphroditeEngine`.
      """

      def __init__(self,
                   ipc_path: str,
                   use_async_sockets: bool,
                   *args,
                   log_requests: bool = True,
                   **kwargs) -> None:
          self.engine = AphroditeEngine(*args, **kwargs)
          self.log_requests = log_requests

          self.use_async_sockets = use_async_sockets
          if self.use_async_sockets:
              # Let the engine drive socket I/O through a callback so the
              # IPC can overlap with GPU work.
              self.engine.process_request_outputs_callback = \
                  self._async_socket_engine_callback

          self.ctx = zmq.Context()  # type: ignore[attr-defined]

          # Receive input from the client.
          self.input_socket = self.ctx.socket(zmq.constants.PULL)
          self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")

          # Send output stream back to client.
          self.output_socket = self.ctx.socket(zmq.constants.PUSH)
          self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")

          # Send health status back to client.
          self.health_socket = self.ctx.socket(zmq.constants.PUSH)
          self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

          # IPC path for the data socket; it is bound on demand by
          # make_data_socket().
          self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

          # First error that put the engine into an errored state, or None.
          self._errored_with: Optional[BaseException] = None

      @property
      def dead_error(self) -> BaseException:
          """Return an ENGINE_DEAD_ERROR, built from the recorded error if any."""
          if self._errored_with is not None:
              return ENGINE_DEAD_ERROR(self._errored_with)
          else:
              return ENGINE_DEAD_ERROR()

      @classmethod
      def from_engine_args(cls, engine_args: AsyncEngineArgs, ipc_path: str):
          """Creates an MQAphroditeEngine from the engine arguments.

          Args:
              engine_args: Arguments used to build the engine config.
              ipc_path: Base path for zeromq interprocess messaging.
          """
          engine_config = engine_args.create_engine_config()

          executor_class = AphroditeEngine._get_executor_cls(engine_config)

          return cls(
              ipc_path=ipc_path,
              use_async_sockets=engine_config.model_config.use_async_output_proc,
              **engine_config.to_dict(),
              executor_class=executor_class,
              log_requests=not engine_args.disable_log_requests,
              log_stats=not engine_args.disable_log_stats)

      def start(self):
          """Run the startup handshake, then the engine loop, until exit.

          Exceptions from either loop are logged rather than propagated, and
          KeyboardInterrupt is treated as a shutdown. :meth:`cleanup` always
          runs afterwards.
          """
          try:
              try:
                  logger.debug("Starting Startup Loop.")
                  self.run_startup_loop()
                  logger.debug("Starting Engine Loop.")
                  self.run_engine_loop()
              except Exception as e:
                  logger.exception(repr(e))
          except KeyboardInterrupt:
              logger.debug("Shutting down MQAphroditeEngine.")
          finally:
              logger.debug("MQAphroditeEngine is shut down.")
              self.cleanup()

      def cleanup(self):
          """Cleanup zeromq state on shutdown and drop the engine reference."""
          # Closes all sockets and destroys context.
          self.ctx.destroy(linger=0)
          del self.engine

      @contextmanager
      def make_data_socket(
              self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
          """Yield a ROUTER socket bound to the data IPC path; close on exit."""
          socket = self.ctx.socket(zmq.constants.ROUTER)
          try:
              socket.bind(self.data_ipc_path)
              yield socket
          finally:
              socket.close(linger=0)

      def run_startup_loop(self) -> None:
          """Answer the client's startup handshake over the data socket.

          Receives one request on a temporary ROUTER socket and replies to its
          sender with an RPCStartupResponse, or with the exception raised
          while handling the request.

          Raises:
              Any error from receiving the request itself. In that case there
              is no client identity to reply to.
          """
          with self.make_data_socket() as socket:
              # Receive outside the try: without the identity frame there is
              # nobody to address an error reply to.
              identity, message = socket.recv_multipart(copy=False)

              response: Union[RPCStartupResponse, BaseException]
              try:
                  request: RPCStartupRequest = pickle.loads(message.buffer)

                  # Handle the query from the Client.
                  if request == RPCStartupRequest.IS_SERVER_READY:
                      response = RPCStartupResponse(tracing_enabled=False)
                  else:
                      # Reply with an error instead of leaving `response`
                      # unset.
                      response = ValueError(
                          f"Unknown RPCStartupRequest: {request}")

              except Exception as e:
                  response = e

              socket.send_multipart((identity, pickle.dumps(response)),
                                    copy=False)

      def run_engine_loop(self):
          """Core busy loop of the AphroditeEngine.

          While there is no unfinished work, block on the input socket in
          POLLING_TIMEOUT_MS slices, logging stats on each timeout. Then
          handle queued input, step the engine, and (in sync mode) send the
          outputs. Errors from input handling or stepping propagate out.
          """
          while True:
              if not self.engine.has_unfinished_requests():
                  # Poll until there is work to do.
                  while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                      self.engine.do_log_stats()
                      logger.debug("Waiting for new requests in engine loop.")

              # Handle any input from the client.
              self.handle_new_input()

              # Engine step.
              request_outputs = self.engine_step()

              # Send request outputs (if async, done in engine_step callback).
              if not self.use_async_sockets:
                  self._send_outputs(request_outputs)

      def engine_step(self) -> List[RequestOutput]:
          """Run one engine step, reporting any failure to the client.

          On any exception other than SystemExit, the engine is marked as
          errored, an engine-level RPCError is sent on the output socket, and
          the exception is re-raised.
          """
          try:
              return self.engine.step()
          except SystemExit:
              raise
          except BaseException as e:
              self._set_errored(e)
              rpc_err = RPCError(request_id=None,
                                 is_engine_errored=True,
                                 exception=e)
              self._send_outputs(rpc_err)
              raise

      def handle_new_input(self):
          """Drain and dispatch all requests queued on the input socket.

          Any error while handling input marks the engine as errored, is sent
          on the health socket, and is re-raised.

          Raises:
              ValueError: If an unrecognised request type is received.
          """
          try:
              # Zero timeout: only process what is already queued.
              while self.input_socket.poll(timeout=0) != 0:
                  frames = self.input_socket.recv_multipart(copy=False)
                  # NOTE(review): pickle/cloudpickle may execute arbitrary code
                  # while loading; the IPC endpoint must only be reachable by
                  # the trusted client.
                  request = pickle.loads(frames[0].buffer)

                  if isinstance(request, RPCProcessRequest):
                      if len(frames) > 1:
                          # Use cloudpickle for logits processors
                          assert isinstance(request.params, SamplingParams)
                          lprocs = cloudpickle.loads(frames[1].buffer)
                          request.params.logits_processors = lprocs
                      self._handle_process_request(request)
                  elif isinstance(request, RPCAbortRequest):
                      self._handle_abort_request(request)
                  elif isinstance(request, RPCHealthRequest):
                      self._handle_health_request()
                  elif isinstance(request, RPCShutdownRequest):
                      self.engine.shutdown()
                      self._send_outputs(APHRODITE_RPC_SUCCESS_STR)
                      # Stop draining; any remaining messages stay queued.
                      break
                  else:
                      raise ValueError(f"Unknown RPCRequest Type: {request}")

          except Exception as e:
              self._set_errored(e)
              self._send_unhealthy(e)
              raise

      def _handle_process_request(self, request: RPCProcessRequest):
          """Handle RPCProcessRequest by adding it to the AphroditeEngine.

          If the engine has already errored, the request is rejected with an
          engine-level RPCError and is not added. If adding it fails, an
          RPCError for this request is sent and the request is aborted.
          """
          request_id = request.request_id

          if self._errored_with is not None:
              rpc_err = RPCError(request_id=request_id,
                                 is_engine_errored=True,
                                 exception=self.dead_error)
              self._send_outputs(rpc_err)
              # Do not hand new work to an engine that has already failed.
              return

          try:
              self.engine.add_request(
                  request_id=request_id,
                  prompt=request.prompt,
                  params=request.params,
                  lora_request=request.lora_request,
                  prompt_adapter_request=request.prompt_adapter_request)

              if self.log_requests:
                  logger.info(f"Added request {request.request_id}.")

          except Exception as e:
              # We do not call self._set_errored here, since the error
              # is due to an issue adding this request to the engine,
              # rather than an issue with the engine itself.
              is_errored = self._errored_with is not None
              rpc_err = RPCError(request_id=request_id,
                                 is_engine_errored=is_errored,
                                 exception=e)
              self._send_outputs(rpc_err)

              # Remove request from the engine.
              self.engine.abort_request(request_id)

      def _handle_abort_request(self, request: RPCAbortRequest):
          """Abort the given request in the engine, optionally logging it."""
          self.engine.abort_request(request.request_id)
          if self.log_requests:
              logger.info(f"Aborted request {request.request_id}.")

      def _handle_health_request(self):
          """Reply on the health socket with the engine's health status.

          Sends the recorded error if the engine has errored. Otherwise runs
          ``engine.check_health()`` and sends a healthy response; an exception
          from ``check_health`` propagates to the caller.
          """
          if self._errored_with is not None:
              self._send_unhealthy(self._errored_with)
              # Don't follow an unhealthy reply with a healthy one.
              return

          # Raises error if unhealthy.
          self.engine.check_health()
          self._send_healthy()

      def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
          """Pickle and send outputs to the RPCClient; falsy outputs are skipped."""
          if outputs:
              output_bytes = pickle.dumps(outputs)
              self.output_socket.send_multipart((output_bytes, ), copy=False)

      def _send_healthy(self):
          """Send HEALTHY message to RPCClient."""
          self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

      def _send_unhealthy(self, error: BaseException):
          """Send UNHEALTHY message (the pickled error) to RPCClient."""
          error_bytes = pickle.dumps(error)
          self.health_socket.send_multipart((error_bytes, ), copy=False)

      def _async_socket_engine_callback(self,
                                        request_outputs: REQUEST_OUTPUTS_T):
          """Callback used by engine to make socket handling async with GPU."""
          self._send_outputs(request_outputs)
          self.handle_new_input()

      def _set_errored(self, e: BaseException):
          """Record ``e`` as the errored state if no error is recorded yet."""
          if self._errored_with is None:
              self._errored_with = e
  def run_mp_engine(engine_args: AsyncEngineArgs, ipc_path: str):
      """Build an MQAphroditeEngine from ``engine_args`` and run it until exit.

      SIGTERM is turned into a KeyboardInterrupt carrying a sentinel message,
      which is swallowed here; any other KeyboardInterrupt is re-raised.

      Args:
          engine_args: Arguments used to construct the engine.
          ipc_path: Base path for the engine's zeromq IPC sockets.
      """
      terminated_msg = "MQAphroditeEngine terminated"

      def signal_handler(*_) -> None:
          # Redirect stderr to devnull (presumably to silence shutdown noise).
          # The handle is intentionally left open: closing it, as a `with`
          # block did, left sys.stderr pointing at a closed file, so any later
          # write to stderr raised ValueError.
          sys.stderr = open(os.devnull, "w")  # noqa: SIM115
          raise KeyboardInterrupt(terminated_msg)

      signal.signal(signal.SIGTERM, signal_handler)

      try:
          engine = MQAphroditeEngine.from_engine_args(engine_args=engine_args,
                                                      ipc_path=ipc_path)
          engine.start()
      except KeyboardInterrupt as e:
          # Only the SIGTERM-originated interrupt is expected here.
          if str(e) != terminated_msg:
              raise
  264. pass
  265. else:
  266. raise