server.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import asyncio
  2. import os
  3. import signal
  4. from typing import Any, Coroutine, Union
  5. import cloudpickle
  6. import zmq
  7. import zmq.asyncio
  8. from loguru import logger
  9. from typing_extensions import Never
  10. from aphrodite import AsyncAphrodite, AsyncEngineArgs
  11. from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
  12. ParallelConfig, SchedulerConfig)
  13. from aphrodite.common.utils import in_windows
  14. from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_SUCCESS_STR,
  15. APHRODITE_RPC_ZMQ_HWM,
  16. RPCAbortRequest,
  17. RPCGenerateRequest,
  18. RPCUtilityRequest)
  19. if in_windows():
  20. import winloop as uvloop
  21. else:
  22. import uvloop
  23. CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
  24. SchedulerConfig, LoRAConfig]
  25. class AsyncEngineRPCServer:
  26. def __init__(self, async_engine_args: AsyncEngineArgs, rpc_path: str):
  27. # Initialize engine first.
  28. self.engine = AsyncAphrodite.from_engine_args(async_engine_args)
  29. # Initialize context.
  30. self.context = zmq.asyncio.Context()
  31. # Init socket.
  32. self.socket = self.context.socket(zmq.constants.DEALER)
  33. self.socket.set_hwm(APHRODITE_RPC_ZMQ_HWM)
  34. self.socket.connect(rpc_path)
  35. def cleanup(self):
  36. """Cleanup all resources."""
  37. self.socket.close()
  38. self.context.destroy()
  39. self.engine.shutdown_background_loop()
  40. # Clear the engine reference so that it can be GC'ed.
  41. self.engine = None
  42. async def get_config(self, identity, request):
  43. try:
  44. config: CONFIG_TYPE
  45. if request == RPCUtilityRequest.GET_MODEL_CONFIG:
  46. config = await self.engine.get_model_config()
  47. elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
  48. config = await self.engine.get_decoding_config()
  49. elif request == RPCUtilityRequest.GET_LORA_CONFIG:
  50. config = await self.engine.get_lora_config()
  51. elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
  52. config = await self.engine.get_scheduler_config()
  53. elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
  54. config = await self.engine.get_parallel_config()
  55. else:
  56. raise ValueError(f"Unknown Config Request: {request}")
  57. await self.socket.send_multipart(
  58. [identity, cloudpickle.dumps(config)])
  59. except Exception as e:
  60. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  61. async def do_log_stats(self, identity):
  62. """Log stats and confirm success."""
  63. await self.engine.do_log_stats()
  64. await self.socket.send_multipart(
  65. [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
  66. async def is_server_ready(self, identity):
  67. """Notify the client that we are ready."""
  68. await self.socket.send_multipart(
  69. [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
  70. async def abort(self, identity, request: RPCAbortRequest):
  71. """Abort request and notify the client of success."""
  72. try:
  73. # Abort the request in the llm engine.
  74. await self.engine.abort(request.request_id)
  75. result: Union[str, Exception] = APHRODITE_RPC_SUCCESS_STR
  76. except Exception as e:
  77. result = e
  78. await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
  79. async def generate(self, identity, generate_request: RPCGenerateRequest):
  80. try:
  81. results_generator = self.engine.generate(
  82. generate_request.inputs,
  83. sampling_params=generate_request.sampling_params,
  84. request_id=generate_request.request_id,
  85. lora_request=generate_request.lora_request,
  86. prompt_adapter_request=generate_request.prompt_adapter_request)
  87. async for request_output in results_generator:
  88. await self.socket.send_multipart(
  89. [identity, cloudpickle.dumps(request_output)])
  90. except Exception as e:
  91. ### Notify client of all failures
  92. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  93. async def check_health(self, identity):
  94. try:
  95. await self.engine.check_health()
  96. await self.socket.send_multipart(
  97. [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
  98. except Exception as e:
  99. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  100. def _make_handler_coro(self, identity,
  101. message) -> Coroutine[Any, Any, Never]:
  102. """Route the zmq message to the handler coroutine."""
  103. request = cloudpickle.loads(message)
  104. if isinstance(request, RPCGenerateRequest):
  105. return self.generate(identity, request)
  106. elif isinstance(request, RPCAbortRequest):
  107. return self.abort(identity, request)
  108. elif isinstance(request, RPCUtilityRequest):
  109. if request in [
  110. RPCUtilityRequest.GET_MODEL_CONFIG,
  111. RPCUtilityRequest.GET_PARALLEL_CONFIG,
  112. RPCUtilityRequest.GET_DECODING_CONFIG,
  113. RPCUtilityRequest.GET_SCHEDULER_CONFIG,
  114. RPCUtilityRequest.GET_LORA_CONFIG
  115. ]:
  116. return self.get_config(identity, request)
  117. elif request == RPCUtilityRequest.DO_LOG_STATS:
  118. return self.do_log_stats(identity)
  119. elif request == RPCUtilityRequest.IS_SERVER_READY:
  120. return self.is_server_ready(identity)
  121. elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
  122. return self.check_health(identity)
  123. elif request == RPCUtilityRequest.SHUTDOWN_SERVER:
  124. return self.shutdown(identity)
  125. else:
  126. raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
  127. else:
  128. raise ValueError(f"Unknown RPCRequest type: {request}")
  129. async def run_server_loop(self):
  130. """Inner RPC Server Loop"""
  131. running_tasks = set()
  132. while True:
  133. # Wait for a request.
  134. identity, message = await self.socket.recv_multipart()
  135. # Process the request async.
  136. task = asyncio.create_task(
  137. self._make_handler_coro(identity, message))
  138. # We need to keep around a strong reference to the task,
  139. # to avoid the task disappearing mid-execution as running tasks
  140. # can be GC'ed. Below is a common "fire-and-forget" tasks
  141. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
  142. running_tasks.add(task)
  143. task.add_done_callback(running_tasks.discard)
  144. async def shutdown(self, identity):
  145. """Handle shutdown request from client."""
  146. try:
  147. # Clean shutdown of engine
  148. self.engine.shutdown_background_loop()
  149. await self.socket.send_multipart(
  150. [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)]
  151. )
  152. except Exception as e:
  153. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  154. finally:
  155. # Schedule server shutdown
  156. asyncio.create_task(self._delayed_shutdown())
  157. async def _delayed_shutdown(self):
  158. """Helper to shut down server after response is sent"""
  159. await asyncio.sleep(1)
  160. self.cleanup()
  161. # Force exit the process
  162. os._exit(0)
  163. async def run_server(server: AsyncEngineRPCServer):
  164. # Put the server task into the asyncio loop.
  165. loop = asyncio.get_running_loop()
  166. server_task = loop.create_task(server.run_server_loop())
  167. # Interruption handling.
  168. def signal_handler() -> None:
  169. # Kill the server on interrupt / terminate
  170. server_task.cancel()
  171. loop.add_signal_handler(signal.SIGINT, signal_handler)
  172. loop.add_signal_handler(signal.SIGTERM, signal_handler)
  173. try:
  174. await server_task
  175. except asyncio.CancelledError:
  176. logger.info("Aphrodite ZMQ RPC Server was interrupted.")
  177. finally:
  178. # Clean up all resources.
  179. server.cleanup()
  180. def run_rpc_server(async_engine_args: AsyncEngineArgs, rpc_path: str):
  181. server = AsyncEngineRPCServer(async_engine_args, rpc_path)
  182. uvloop.run(run_server(server))