server.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import asyncio
  2. import signal
  3. from typing import Any, Coroutine, Union
  4. import cloudpickle
  5. import zmq
  6. import zmq.asyncio
  7. from loguru import logger
  8. from typing_extensions import Never
  9. from aphrodite import AsyncAphrodite, AsyncEngineArgs
  10. from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
  11. ParallelConfig, SchedulerConfig)
  12. from aphrodite.common.utils import in_windows
  13. from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_SUCCESS_STR,
  14. APHRODITE_RPC_ZMQ_HWM,
  15. RPCAbortRequest,
  16. RPCGenerateRequest,
  17. RPCUtilityRequest)
  18. if in_windows():
  19. import winloop as uvloop
  20. else:
  21. import uvloop
  22. CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
  23. SchedulerConfig, LoRAConfig]
  24. class AsyncEngineRPCServer:
  25. def __init__(self, async_engine_args: AsyncEngineArgs, rpc_path: str):
  26. # Initialize engine first.
  27. self.engine = AsyncAphrodite.from_engine_args(async_engine_args)
  28. # Initialize context.
  29. self.context = zmq.asyncio.Context()
  30. # Init socket.
  31. self.socket = self.context.socket(zmq.constants.DEALER)
  32. self.socket.set_hwm(APHRODITE_RPC_ZMQ_HWM)
  33. self.socket.connect(rpc_path)
  34. def cleanup(self):
  35. """Cleanup all resources."""
  36. self.socket.close()
  37. self.context.destroy()
  38. self.engine.shutdown_background_loop()
  39. # Clear the engine reference so that it can be GC'ed.
  40. self.engine = None
  41. async def get_config(self, identity, request):
  42. try:
  43. config: CONFIG_TYPE
  44. if request == RPCUtilityRequest.GET_MODEL_CONFIG:
  45. config = await self.engine.get_model_config()
  46. elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
  47. config = await self.engine.get_decoding_config()
  48. elif request == RPCUtilityRequest.GET_LORA_CONFIG:
  49. config = await self.engine.get_lora_config()
  50. elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
  51. config = await self.engine.get_scheduler_config()
  52. elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
  53. config = await self.engine.get_parallel_config()
  54. else:
  55. raise ValueError(f"Unknown Config Request: {request}")
  56. await self.socket.send_multipart(
  57. [identity, cloudpickle.dumps(config)])
  58. except Exception as e:
  59. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  60. async def do_log_stats(self, identity):
  61. """Log stats and confirm success."""
  62. await self.engine.do_log_stats()
  63. await self.socket.send_multipart(
  64. [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
  65. async def is_server_ready(self, identity):
  66. """Notify the client that we are ready."""
  67. await self.socket.send_multipart(
  68. [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
  69. async def abort(self, identity, request: RPCAbortRequest):
  70. """Abort request and notify the client of success."""
  71. try:
  72. # Abort the request in the llm engine.
  73. await self.engine.abort(request.request_id)
  74. result: Union[str, Exception] = APHRODITE_RPC_SUCCESS_STR
  75. except Exception as e:
  76. result = e
  77. await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
  78. async def generate(self, identity, generate_request: RPCGenerateRequest):
  79. try:
  80. results_generator = self.engine.generate(
  81. generate_request.inputs,
  82. sampling_params=generate_request.sampling_params,
  83. request_id=generate_request.request_id,
  84. lora_request=generate_request.lora_request,
  85. prompt_adapter_request=generate_request.prompt_adapter_request)
  86. async for request_output in results_generator:
  87. await self.socket.send_multipart(
  88. [identity, cloudpickle.dumps(request_output)])
  89. except Exception as e:
  90. ### Notify client of all failures
  91. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  92. async def check_health(self, identity):
  93. try:
  94. await self.engine.check_health()
  95. await self.socket.send_multipart(
  96. [identity, cloudpickle.dumps(APHRODITE_RPC_SUCCESS_STR)])
  97. except Exception as e:
  98. await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
  99. def _make_handler_coro(self, identity,
  100. message) -> Coroutine[Any, Any, Never]:
  101. """Route the zmq message to the handler coroutine."""
  102. request = cloudpickle.loads(message)
  103. if isinstance(request, RPCGenerateRequest):
  104. return self.generate(identity, request)
  105. elif isinstance(request, RPCAbortRequest):
  106. return self.abort(identity, request)
  107. elif isinstance(request, RPCUtilityRequest):
  108. if request in [
  109. RPCUtilityRequest.GET_MODEL_CONFIG,
  110. RPCUtilityRequest.GET_PARALLEL_CONFIG,
  111. RPCUtilityRequest.GET_DECODING_CONFIG,
  112. RPCUtilityRequest.GET_SCHEDULER_CONFIG,
  113. RPCUtilityRequest.GET_LORA_CONFIG
  114. ]:
  115. return self.get_config(identity, request)
  116. elif request == RPCUtilityRequest.DO_LOG_STATS:
  117. return self.do_log_stats(identity)
  118. elif request == RPCUtilityRequest.IS_SERVER_READY:
  119. return self.is_server_ready(identity)
  120. elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
  121. return self.check_health(identity)
  122. else:
  123. raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
  124. else:
  125. raise ValueError(f"Unknown RPCRequest type: {request}")
  126. async def run_server_loop(self):
  127. """Inner RPC Server Loop"""
  128. running_tasks = set()
  129. while True:
  130. # Wait for a request.
  131. identity, message = await self.socket.recv_multipart()
  132. # Process the request async.
  133. task = asyncio.create_task(
  134. self._make_handler_coro(identity, message))
  135. # We need to keep around a strong reference to the task,
  136. # to avoid the task disappearing mid-execution as running tasks
  137. # can be GC'ed. Below is a common "fire-and-forget" tasks
  138. # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
  139. running_tasks.add(task)
  140. task.add_done_callback(running_tasks.discard)
  141. async def run_server(server: AsyncEngineRPCServer):
  142. # Put the server task into the asyncio loop.
  143. loop = asyncio.get_running_loop()
  144. server_task = loop.create_task(server.run_server_loop())
  145. # Interruption handling.
  146. def signal_handler() -> None:
  147. # Kill the server on interrupt / terminate
  148. server_task.cancel()
  149. loop.add_signal_handler(signal.SIGINT, signal_handler)
  150. loop.add_signal_handler(signal.SIGTERM, signal_handler)
  151. try:
  152. await server_task
  153. except asyncio.CancelledError:
  154. logger.info("Aphrodite ZMQ RPC Server was interrupted.")
  155. finally:
  156. # Clean up all resources.
  157. server.cleanup()
  158. def run_rpc_server(async_engine_args: AsyncEngineArgs, rpc_path: str):
  159. server = AsyncEngineRPCServer(async_engine_args, rpc_path)
  160. uvloop.run(run_server(server))