# server.py — ZMQ-based RPC server for the Aphrodite async engine.
  1. import asyncio
  2. import signal
  3. from typing import Any, Coroutine
  4. import cloudpickle
  5. import zmq
  6. import zmq.asyncio
  7. from loguru import logger
  8. from typing_extensions import Never
  9. from aphrodite import AsyncAphrodite, AsyncEngineArgs
  10. from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
  11. APHRODITE_RPC_SUCCESS_STR,
  12. RPCAbortRequest,
  13. RPCGenerateRequest,
  14. RPCUtilityRequest)
  15. class AsyncEngineRPCServer:
  16. def __init__(self, async_engine_args: AsyncEngineArgs, port: int):
  17. # Initialize engine first.
  18. self.engine = AsyncAphrodite.from_engine_args(async_engine_args)
  19. # Initialize context.
  20. self.context = zmq.asyncio.Context()
  21. # Init socket for readiness state.
  22. self.socket = self.context.socket(zmq.constants.ROUTER)
  23. # NOTE numeric form of localhost should be used for zmq bind(),
  24. # see https://stackoverflow.com/a/8958414
  25. self.socket.bind(f"tcp://127.0.0.1:{port}")
  26. def cleanup(self):
  27. """Cleanup all resources."""
  28. self.socket.close()
  29. self.context.destroy()
  30. async def get_model_config(self, identity):
  31. """Send the ModelConfig"""
  32. model_config = await self.engine.get_model_config()
  33. await self.socket.send_multipart(
  34. [identity, cloudpickle.dumps(model_config)])
  35. async def get_decoding_config(self, identity):
  36. """Send the DecodingConfig"""
  37. decoding_config = await self.engine.get_decoding_config()
  38. await self.socket.send_multipart(
  39. [identity, cloudpickle.dumps(decoding_config)])
  40. async def get_lora_config(self, identity):
  41. lora_config = await self.engine.get_lora_config()
  42. await self.socket.send_multipart(
  43. [identity, cloudpickle.dumps(lora_config)])
  44. async def get_scheduler_config(self, identity):
  45. """Send the SchedulerConfig"""
  46. parallel_config = await self.engine.get_scheduler_config()
  47. await self.socket.send_multipart(
  48. [identity, cloudpickle.dumps(parallel_config)])
  49. async def get_parallel_config(self, identity):
  50. """Send the ParallelConfig"""
  51. parallel_config = await self.engine.get_parallel_config()
  52. await self.socket.send_multipart(
  53. [identity, cloudpickle.dumps(parallel_config)])
  54. async def do_log_stats(self, identity):
  55. """Log stats and confirm success."""
  56. await self.engine.do_log_stats()
  57. await self.socket.send_multipart([
  58. identity,
  59. cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
  60. ])
  61. async def is_server_ready(self, identity):
  62. """Notify the client that we are ready."""
  63. await self.socket.send_multipart([
  64. identity,
  65. cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
  66. ])
  67. async def abort(self, identity, request: RPCAbortRequest):
  68. """Abort request and notify the client of success."""
  69. # Abort the request in the llm engine.
  70. await self.engine.abort(request.request_id)
  71. # Send confirmation to the client.
  72. await self.socket.send_multipart([
  73. identity,
  74. cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR),
  75. ])
  76. async def generate(self, identity, generate_request: RPCGenerateRequest):
  77. try:
  78. results_generator = self.engine.generate(
  79. generate_request.inputs,
  80. sampling_params=generate_request.sampling_params,
  81. request_id=generate_request.request_id,
  82. lora_request=generate_request.lora_request,
  83. prompt_adapter_request=generate_request.prompt_adapter_request)
  84. async for request_output in results_generator:
  85. await self.socket.send_multipart(
  86. [identity, cloudpickle.dumps(request_output)])
  87. except Exception as e:
  88. ### Notify client of all failures
  89. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  90. async def check_health(self, identity):
  91. try:
  92. await self.engine.check_health()
  93. await self.socket.send_multipart(
  94. [identity,
  95. cloudpickle.dumps(APHRODITE_RPC_HEALTHY_STR)])
  96. except Exception as e:
  97. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  98. def _make_handler_coro(self, identity,
  99. message) -> Coroutine[Any, Any, Never]:
  100. """Route the zmq message to the handler coroutine."""
  101. request = cloudpickle.loads(message)
  102. if isinstance(request, RPCGenerateRequest):
  103. return self.generate(identity, request)
  104. elif isinstance(request, RPCAbortRequest):
  105. return self.abort(identity, request)
  106. elif isinstance(request, RPCUtilityRequest):
  107. if request == RPCUtilityRequest.GET_MODEL_CONFIG:
  108. return self.get_model_config(identity)
  109. elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
  110. return self.get_parallel_config(identity)
  111. elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
  112. return self.get_decoding_config(identity)
  113. elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
  114. return self.get_scheduler_config(identity)
  115. elif request == RPCUtilityRequest.GET_LORA_CONFIG:
  116. return self.get_lora_config(identity)
  117. elif request == RPCUtilityRequest.DO_LOG_STATS:
  118. return self.do_log_stats(identity)
  119. elif request == RPCUtilityRequest.IS_SERVER_READY:
  120. return self.is_server_ready(identity)
  121. elif request == RPCUtilityRequest.CHECK_HEALTH:
  122. return self.check_health(identity)
  123. else:
  124. raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
  125. else:
  126. raise ValueError(f"Unknown RPCRequest type: {request}")
  127. async def run_server_loop(self):
  128. """Inner RPC Server Loop"""
  129. running_tasks = set()
  130. while True:
  131. # Wait for a request.
  132. identity, message = await self.socket.recv_multipart()
  133. # Process the request async.
  134. task = asyncio.create_task(
  135. self._make_handler_coro(identity, message))
  136. # We need to keep around a strong reference to the task,
  137. # to avoid the task disappearing mid-execution as running tasks
  138. # can be GC'ed. Below is a common "fire-and-forget" tasks
  139. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
  140. running_tasks.add(task)
  141. task.add_done_callback(running_tasks.discard)
  142. async def run_server(server: AsyncEngineRPCServer):
  143. # Put the server task into the asyncio loop.
  144. loop = asyncio.get_running_loop()
  145. server_task = loop.create_task(server.run_server_loop())
  146. # Interruption handling.
  147. def signal_handler() -> None:
  148. # Kill the server on interrupt / terminate
  149. server_task.cancel()
  150. loop.add_signal_handler(signal.SIGINT, signal_handler)
  151. loop.add_signal_handler(signal.SIGTERM, signal_handler)
  152. try:
  153. await server_task
  154. except asyncio.CancelledError:
  155. logger.info("Aphrodite ZMQ RPC Server was interrupted.")
  156. finally:
  157. # Clean up all resources.
  158. server.cleanup()
  159. def run_rpc_server(async_engine_args: AsyncEngineArgs, port: int):
  160. server = AsyncEngineRPCServer(async_engine_args, port)
  161. asyncio.run(run_server(server))