server.py

import asyncio
import signal
from typing import Any, Coroutine

import cloudpickle
import uvloop
import zmq
import zmq.asyncio
from loguru import logger
from typing_extensions import Never

from aphrodite import AsyncAphrodite, AsyncEngineArgs
from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
                                            APHRODITE_RPC_SUCCESS_STR,
                                            RPCAbortRequest,
                                            RPCGenerateRequest,
                                            RPCUtilityRequest)


class AsyncEngineRPCServer:

    def __init__(self, async_engine_args: AsyncEngineArgs, rpc_path: str):
        # Initialize engine first.
        self.engine = AsyncAphrodite.from_engine_args(async_engine_args)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket for readiness state.
        self.socket = self.context.socket(zmq.constants.ROUTER)
        self.socket.bind(rpc_path)

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()
        self.engine.shutdown_background_loop()
        # Clear the engine reference so that it can be GC'ed.
        self.engine = None

    async def get_model_config(self, identity):
        """Send the ModelConfig."""
        model_config = await self.engine.get_model_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(model_config)])

    async def get_decoding_config(self, identity):
        """Send the DecodingConfig."""
        decoding_config = await self.engine.get_decoding_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(decoding_config)])

    async def get_lora_config(self, identity):
        """Send the LoRAConfig."""
        lora_config = await self.engine.get_lora_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(lora_config)])

    async def get_scheduler_config(self, identity):
        """Send the SchedulerConfig."""
        scheduler_config = await self.engine.get_scheduler_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(scheduler_config)])

    async def get_parallel_config(self, identity):
        """Send the ParallelConfig."""
        parallel_config = await self.engine.get_parallel_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])

    async def do_log_stats(self, identity):
        """Log stats and confirm success."""
        await self.engine.do_log_stats()

        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
        ])

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
        ])

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        try:
            # Abort the request in the llm engine.
            await self.engine.abort(request.request_id)
        except Exception:
            logger.warning(f"Failed to abort request {request.request_id}")
        finally:
            # Send confirmation to the client.
            await self.socket.send_multipart([
                identity,
                cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
            ])

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        """Stream generation results back to the client."""
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            # Stream each RequestOutput to the client as it is produced.
            async for request_output in results_generator:
                await self.socket.send_multipart(
                    [identity, cloudpickle.dumps(request_output)])

        except Exception as e:
            # Notify the client of any failure by sending the exception itself.
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    async def check_health(self, identity):
        """Check engine health and report the result to the client."""
        try:
            await self.engine.check_health()
            await self.socket.send_multipart(
                [identity, cloudpickle.dumps(APHRODITE_RPC_HEALTHY_STR)])
        except Exception as e:
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""
        request = cloudpickle.loads(message)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                return self.get_model_config(identity)
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                return self.get_parallel_config(identity)
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                return self.get_decoding_config(identity)
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                return self.get_scheduler_config(identity)
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                return self.get_lora_config(identity)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.CHECK_HEALTH:
                return self.check_health(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")

    async def run_server_loop(self):
        """Inner RPC Server Loop."""
        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart()

            # Process the request asynchronously.
            task = asyncio.create_task(
                self._make_handler_coro(identity, message))

            # Keep a strong reference to the task so it is not garbage
            # collected mid-execution; the event loop only holds weak
            # references to running tasks. This is the recommended pattern
            # for "fire-and-forget" tasks:
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    # Put the server task into the asyncio loop.
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    # Interruption handling.
    def signal_handler() -> None:
        # Kill the server on interrupt / terminate.
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("Aphrodite ZMQ RPC Server was interrupted.")
    finally:
        # Clean up all resources.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs, rpc_path: str):
    server = AsyncEngineRPCServer(async_engine_args, rpc_path)
    uvloop.run(run_server(server))
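
The round trip this server implements can be exercised with a small client, sketched below. This is an illustrative sketch rather than Aphrodite's actual client: it assumes the server has already been started via run_rpc_server on the same rpc_path, and the function name wait_until_ready and the example ipc path are invented for the example. The DEALER socket, the cloudpickle framing, and the IS_SERVER_READY / APHRODITE_RPC_SUCCESS_STR exchange mirror what the ROUTER loop above expects.

import asyncio

import cloudpickle
import zmq
import zmq.asyncio

from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_SUCCESS_STR,
                                            RPCUtilityRequest)

# Illustrative path; use the same rpc_path the server was bound to.
RPC_PATH = "ipc:///tmp/aphrodite_rpc_demo"


async def wait_until_ready(rpc_path: str) -> bool:
    """Ask the RPC server whether it is ready to accept requests."""
    context = zmq.asyncio.Context()
    # DEALER pairs with the server's ROUTER socket.
    socket = context.socket(zmq.constants.DEALER)
    socket.connect(rpc_path)
    try:
        # The server cloudpickle-loads the payload frame and routes it
        # to is_server_ready().
        await socket.send_multipart(
            [cloudpickle.dumps(RPCUtilityRequest.IS_SERVER_READY)])
        # The ROUTER strips the identity frame, so only the payload arrives.
        frames = await socket.recv_multipart()
        return cloudpickle.loads(frames[0]) == APHRODITE_RPC_SUCCESS_STR
    finally:
        socket.close()
        context.destroy()


if __name__ == "__main__":
    print(asyncio.run(wait_until_ready(RPC_PATH)))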