# client.py
  1. from contextlib import contextmanager
  2. from typing import Any, AsyncGenerator, Optional
  3. import cloudpickle
  4. import zmq
  5. import zmq.asyncio
  6. from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
  7. ParallelConfig, SchedulerConfig)
  8. from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
  9. from aphrodite.common.sampling_params import SamplingParams
  10. from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
  11. APHRODITE_RPC_SUCCESS_STR,
  12. RPC_REQUEST_TYPE, RPCAbortRequest,
  13. RPCGenerateRequest,
  14. RPCUtilityRequest)
  15. from aphrodite.inputs import PromptInputs
  16. from aphrodite.lora.request import LoRARequest
  17. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  18. from aphrodite.transformers_utils.tokenizer_group import (
  19. init_tokenizer_from_configs)
# Time (in milliseconds) to wait before checking if the server process
# is alive.
# NOTE(review): not referenced anywhere in this module as shown —
# presumably consumed by the code that launches the server process;
# confirm before removing.
SERVER_START_TIMEOUT_MS = 1000
  22. class AsyncEngineRPCClient:
  23. def __init__(self, rpc_path: str):
  24. self.context = zmq.asyncio.Context()
  25. self.rpc_path = rpc_path
  26. async def setup(self):
  27. """Setup the client before it starts sending server requests."""
  28. # Wait until server is ready.
  29. await self.wait_for_server()
  30. self._errored = False
  31. # Get the configs.
  32. self.model_config = await self._get_model_config_rpc()
  33. self.decoding_config = await self._get_decoding_config_rpc()
  34. # Create the tokenizer group.
  35. # TODO: refactor OAI server to avoid needing this info.
  36. self.tokenizer = init_tokenizer_from_configs(
  37. model_config=self.model_config,
  38. scheduler_config=(await self._get_scheduler_config_rpc()),
  39. parallel_config=(await self._get_parallel_config_rpc()),
  40. enable_lora=bool(await self._get_lora_config_rpc()),
  41. )
  42. def close(self):
  43. """Destroy the ZeroMQ Context."""
  44. self.context.destroy()
  45. @contextmanager
  46. def socket(self):
  47. # Ensure client sockets are always closed after use
  48. # Connect to RPC socket for Request-Reply pattern,
  49. # Note that we use DEALER to enable asynchronous communication
  50. # to enable streaming.
  51. socket = self.context.socket(zmq.constants.DEALER)
  52. try:
  53. socket.connect(self.rpc_path)
  54. yield socket
  55. finally:
  56. # linger == 0 means discard unsent messages
  57. # when the socket is closed. This is necessary
  58. # because otherwise self.context.destroy() will
  59. # wait for 30 seconds until unsent messages are
  60. # received, which is impossible if the server
  61. # crashed. In the absence of a server crash we
  62. # always expect a response before closing the
  63. # socket anyway.
  64. # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
  65. socket.close(linger=0)
  66. async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
  67. expected_type: Any,
  68. error_message: str) -> Any:
  69. """Send an RPC request that is expecting data back."""
  70. with self.socket() as socket:
  71. # Ping RPCServer with a request.
  72. await socket.send(cloudpickle.dumps(request))
  73. # Await the data from the Server.
  74. data = cloudpickle.loads(await socket.recv())
  75. if not isinstance(data, expected_type):
  76. # LoRAConfig can be None.
  77. if expected_type == LoRAConfig and data is None:
  78. pass
  79. else:
  80. raise ValueError(error_message)
  81. return data
  82. async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
  83. error_message: str):
  84. """Send one-way RPC request to trigger an action."""
  85. with self.socket() as socket:
  86. # Ping RPC Server with request.
  87. await socket.send(cloudpickle.dumps(request))
  88. # Await acknowledgement from RPCServer.
  89. response = cloudpickle.loads(await socket.recv())
  90. if not isinstance(response, str) or response != \
  91. APHRODITE_RPC_SUCCESS_STR:
  92. raise ValueError(error_message)
  93. return response
  94. async def get_tokenizer(self, lora_request: LoRARequest):
  95. return await self.tokenizer.get_lora_tokenizer_async(lora_request)
  96. async def get_decoding_config(self) -> DecodingConfig:
  97. return self.decoding_config
  98. async def get_model_config(self) -> ModelConfig:
  99. return self.model_config
  100. async def wait_for_server(self):
  101. """Wait for the RPCServer to start up."""
  102. await self._send_one_way_rpc_request(
  103. request=RPCUtilityRequest.IS_SERVER_READY,
  104. error_message="Unable to start RPC Server.")
  105. async def _get_model_config_rpc(self) -> ModelConfig:
  106. """Get the ModelConfig object from the RPC Server"""
  107. return await self._send_get_data_rpc_request(
  108. RPCUtilityRequest.GET_MODEL_CONFIG,
  109. expected_type=ModelConfig,
  110. error_message="Could not get ModelConfig from RPC Server")
  111. async def _get_decoding_config_rpc(self) -> DecodingConfig:
  112. """Get DecodingConfig from the RPCServer"""
  113. return await self._send_get_data_rpc_request(
  114. RPCUtilityRequest.GET_DECODING_CONFIG,
  115. expected_type=DecodingConfig,
  116. error_message="Could not get DecodingConfig from RPC Server")
  117. async def _get_parallel_config_rpc(self) -> ParallelConfig:
  118. """Get ParallelConfig from the RPCServer"""
  119. return await self._send_get_data_rpc_request(
  120. RPCUtilityRequest.GET_PARALLEL_CONFIG,
  121. expected_type=ParallelConfig,
  122. error_message="Could not get ParallelConfig from RPC Server")
  123. async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
  124. """Get SchedulerConfig from the RPCServer"""
  125. return await self._send_get_data_rpc_request(
  126. RPCUtilityRequest.GET_SCHEDULER_CONFIG,
  127. expected_type=SchedulerConfig,
  128. error_message="Could not get SchedulerConfig from RPC Server")
  129. async def _get_lora_config_rpc(self) -> LoRAConfig:
  130. """Get LoRAConfig from the RPCServer"""
  131. return await self._send_get_data_rpc_request(
  132. RPCUtilityRequest.GET_LORA_CONFIG,
  133. expected_type=LoRAConfig,
  134. error_message="Could not get LoRAConfig from RPC Server")
  135. async def abort(self, request_id: str):
  136. """Send an ABORT_REQUEST signal to the RPC Server"""
  137. await self._send_one_way_rpc_request(
  138. request=RPCAbortRequest(request_id),
  139. error_message=f"RPCAbortRequest {request_id} failed")
  140. async def do_log_stats(self):
  141. """Send a DO_LOG_STATS signal to the RPC Server"""
  142. await self._send_one_way_rpc_request(
  143. request=RPCUtilityRequest.DO_LOG_STATS,
  144. error_message="RPCRequest DO_LOG_STATS failed.")
  145. @property
  146. def is_running(self) -> bool:
  147. return not self._errored
  148. @property
  149. def is_stopped(self) -> bool:
  150. return self._errored
  151. @property
  152. def errored(self) -> bool:
  153. return self._errored
  154. async def generate(
  155. self,
  156. inputs: PromptInputs,
  157. sampling_params: SamplingParams,
  158. request_id: str,
  159. lora_request: Optional[LoRARequest] = None,
  160. prompt_adapter_request: Optional[PromptAdapterRequest] = None
  161. ) -> AsyncGenerator[RequestOutput, None]:
  162. """Send an RPCGenerateRequest to the RPCServer and stream responses."""
  163. finished = False
  164. try:
  165. with self.socket() as socket:
  166. # Send RPCGenerateRequest to the RPCServer.
  167. await socket.send_multipart([
  168. cloudpickle.dumps(
  169. RPCGenerateRequest(
  170. inputs=inputs,
  171. sampling_params=sampling_params,
  172. request_id=request_id,
  173. lora_request=lora_request,
  174. prompt_adapter_request=prompt_adapter_request))
  175. ])
  176. # Stream back the results from the RPC Server.
  177. while not finished:
  178. message = await socket.recv()
  179. request_output = cloudpickle.loads(message)
  180. if isinstance(request_output, Exception):
  181. # On exception, check if the server is still healthy.
  182. # Use this to set the sync `is_running` and `errored`
  183. # properties.
  184. try:
  185. await self.check_health()
  186. except Exception:
  187. self._errored = True
  188. # NB: do before raising here so that the flag is set
  189. # by the time the caller receives this exception
  190. raise request_output
  191. finished = request_output.finished
  192. yield request_output
  193. finally:
  194. if not finished:
  195. await self.abort(request_id)
  196. async def check_health(self) -> None:
  197. """Raise if unhealthy"""
  198. with self.socket() as socket:
  199. # Ping RPCServer with CHECK_HEALTH request.
  200. await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
  201. )
  202. # Await the reply from the server.
  203. # TODO: do we need an internal timeout here?
  204. # Or do we expect the external probe to timeout and let this chill?
  205. health_message = cloudpickle.loads(await socket.recv())
  206. if isinstance(health_message, Exception):
  207. raise health_message
  208. if health_message != APHRODITE_RPC_HEALTHY_STR:
  209. raise ValueError("Expected healthy response from backend but got "
  210. f"{health_message}")
  211. async def encode(self, *args,
  212. **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
  213. raise NotImplementedError(
  214. "Embeddings not supported with multiprocessing backend")