# server.py
import asyncio
import signal
from typing import Any, Coroutine

import cloudpickle
import zmq
import zmq.asyncio
from loguru import logger
from typing_extensions import Never

from aphrodite import AsyncAphrodite, AsyncEngineArgs
from aphrodite.common.utils import in_windows
from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
                                            APHRODITE_RPC_SUCCESS_STR,
                                            RPCAbortRequest,
                                            RPCGenerateRequest,
                                            RPCUtilityRequest)

# uvloop does not support Windows; winloop is used as a drop-in
# replacement there (same `uvloop.run(...)` entry point).
if in_windows():
    import winloop as uvloop
else:
    import uvloop
  20. class AsyncEngineRPCServer:
  21. def __init__(self, async_engine_args: AsyncEngineArgs, rpc_path: str):
  22. # Initialize engine first.
  23. self.engine = AsyncAphrodite.from_engine_args(async_engine_args)
  24. # Initialize context.
  25. self.context = zmq.asyncio.Context()
  26. # Init socket for readiness state.
  27. self.socket = self.context.socket(zmq.constants.ROUTER)
  28. self.socket.bind(rpc_path)
  29. def cleanup(self):
  30. """Cleanup all resources."""
  31. self.socket.close()
  32. self.context.destroy()
  33. self.engine.shutdown_background_loop()
  34. # Clear the engine reference so that it can be GC'ed.
  35. self.engine = None
  36. async def get_model_config(self, identity):
  37. """Send the ModelConfig"""
  38. model_config = await self.engine.get_model_config()
  39. await self.socket.send_multipart(
  40. [identity, cloudpickle.dumps(model_config)])
  41. async def get_decoding_config(self, identity):
  42. """Send the DecodingConfig"""
  43. decoding_config = await self.engine.get_decoding_config()
  44. await self.socket.send_multipart(
  45. [identity, cloudpickle.dumps(decoding_config)])
  46. async def get_lora_config(self, identity):
  47. lora_config = await self.engine.get_lora_config()
  48. await self.socket.send_multipart(
  49. [identity, cloudpickle.dumps(lora_config)])
  50. async def get_scheduler_config(self, identity):
  51. """Send the SchedulerConfig"""
  52. parallel_config = await self.engine.get_scheduler_config()
  53. await self.socket.send_multipart(
  54. [identity, cloudpickle.dumps(parallel_config)])
  55. async def get_parallel_config(self, identity):
  56. """Send the ParallelConfig"""
  57. parallel_config = await self.engine.get_parallel_config()
  58. await self.socket.send_multipart(
  59. [identity, cloudpickle.dumps(parallel_config)])
  60. async def do_log_stats(self, identity):
  61. """Log stats and confirm success."""
  62. await self.engine.do_log_stats()
  63. await self.socket.send_multipart([
  64. identity,
  65. cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
  66. ])
  67. async def is_server_ready(self, identity):
  68. """Notify the client that we are ready."""
  69. await self.socket.send_multipart([
  70. identity,
  71. cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
  72. ])
  73. async def abort(self, identity, request: RPCAbortRequest):
  74. """Abort request and notify the client of success."""
  75. try:
  76. # Abort the request in the llm engine.
  77. await self.engine.abort(request.request_id)
  78. except Exception:
  79. logger.warning(f"Failed to abort request {request.request_id}")
  80. finally:
  81. # Send confirmation to the client.
  82. await self.socket.send_multipart([
  83. identity,
  84. cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
  85. ])
  86. async def generate(self, identity, generate_request: RPCGenerateRequest):
  87. try:
  88. results_generator = self.engine.generate(
  89. generate_request.inputs,
  90. sampling_params=generate_request.sampling_params,
  91. request_id=generate_request.request_id,
  92. lora_request=generate_request.lora_request,
  93. prompt_adapter_request=generate_request.prompt_adapter_request)
  94. async for request_output in results_generator:
  95. await self.socket.send_multipart(
  96. [identity, cloudpickle.dumps(request_output)])
  97. except Exception as e:
  98. ### Notify client of all failures
  99. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  100. async def check_health(self, identity):
  101. try:
  102. await self.engine.check_health()
  103. await self.socket.send_multipart(
  104. [identity,
  105. cloudpickle.dumps(APHRODITE_RPC_HEALTHY_STR)])
  106. except Exception as e:
  107. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  108. def _make_handler_coro(self, identity,
  109. message) -> Coroutine[Any, Any, Never]:
  110. """Route the zmq message to the handler coroutine."""
  111. request = cloudpickle.loads(message)
  112. if isinstance(request, RPCGenerateRequest):
  113. return self.generate(identity, request)
  114. elif isinstance(request, RPCAbortRequest):
  115. return self.abort(identity, request)
  116. elif isinstance(request, RPCUtilityRequest):
  117. if request == RPCUtilityRequest.GET_MODEL_CONFIG:
  118. return self.get_model_config(identity)
  119. elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
  120. return self.get_parallel_config(identity)
  121. elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
  122. return self.get_decoding_config(identity)
  123. elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
  124. return self.get_scheduler_config(identity)
  125. elif request == RPCUtilityRequest.GET_LORA_CONFIG:
  126. return self.get_lora_config(identity)
  127. elif request == RPCUtilityRequest.DO_LOG_STATS:
  128. return self.do_log_stats(identity)
  129. elif request == RPCUtilityRequest.IS_SERVER_READY:
  130. return self.is_server_ready(identity)
  131. elif request == RPCUtilityRequest.CHECK_HEALTH:
  132. return self.check_health(identity)
  133. else:
  134. raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
  135. else:
  136. raise ValueError(f"Unknown RPCRequest type: {request}")
  137. async def run_server_loop(self):
  138. """Inner RPC Server Loop"""
  139. running_tasks = set()
  140. while True:
  141. # Wait for a request.
  142. identity, message = await self.socket.recv_multipart()
  143. # Process the request async.
  144. task = asyncio.create_task(
  145. self._make_handler_coro(identity, message))
  146. # We need to keep around a strong reference to the task,
  147. # to avoid the task disappearing mid-execution as running tasks
  148. # can be GC'ed. Below is a common "fire-and-forget" tasks
  149. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
  150. running_tasks.add(task)
  151. task.add_done_callback(running_tasks.discard)
  152. async def run_server(server: AsyncEngineRPCServer):
  153. # Put the server task into the asyncio loop.
  154. loop = asyncio.get_running_loop()
  155. server_task = loop.create_task(server.run_server_loop())
  156. # Interruption handling.
  157. def signal_handler() -> None:
  158. # Kill the server on interrupt / terminate
  159. server_task.cancel()
  160. loop.add_signal_handler(signal.SIGINT, signal_handler)
  161. loop.add_signal_handler(signal.SIGTERM, signal_handler)
  162. try:
  163. await server_task
  164. except asyncio.CancelledError:
  165. logger.info("Aphrodite ZMQ RPC Server was interrupted.")
  166. finally:
  167. # Clean up all resources.
  168. server.cleanup()
  169. def run_rpc_server(async_engine_args: AsyncEngineArgs, rpc_path: str):
  170. server = AsyncEngineRPCServer(async_engine_args, rpc_path)
  171. uvloop.run(run_server(server))