serving_chat.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. import codecs
  2. import time
  3. from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
  4. from fastapi import Request
  5. from loguru import logger
  6. from aphrodite.common.outputs import RequestOutput
  7. from aphrodite.common.utils import random_uuid
  8. from aphrodite.endpoints.openai.protocol import (
  9. ChatCompletionRequest,
  10. ChatCompletionResponse,
  11. ChatCompletionResponseChoice,
  12. ChatCompletionResponseStreamChoice,
  13. ChatCompletionStreamResponse,
  14. ChatMessage,
  15. DeltaMessage,
  16. ErrorResponse,
  17. UsageInfo,
  18. )
  19. from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
  20. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  21. from aphrodite.modeling.guided_decoding import (
  22. get_guided_decoding_logits_processor)
  23. class OpenAIServingChat(OpenAIServing):
  24. def __init__(self,
  25. engine: AsyncAphrodite,
  26. served_model_names: List[str],
  27. response_role: str,
  28. lora_modules: Optional[List[LoRA]] = None,
  29. chat_template=None):
  30. super().__init__(engine=engine,
  31. served_model_names=served_model_names,
  32. lora_modules=lora_modules)
  33. self.response_role = response_role
  34. self._load_chat_template(chat_template)
  35. async def create_chat_completion(
  36. self, request: ChatCompletionRequest, raw_request: Request
  37. ) -> Union[ErrorResponse, AsyncGenerator[str, None],
  38. ChatCompletionResponse]:
  39. """Completion API similar to OpenAI's API.
  40. See https://platform.openai.com/docs/api-reference/chat/create
  41. for the API specification. This API mimics the OpenAI
  42. ChatCompletion API.
  43. NOTE: Currently we do not support the following feature:
  44. - function_call (Users should implement this by themselves)
  45. """
  46. error_check_ret = await self._check_model(request)
  47. if error_check_ret is not None:
  48. return error_check_ret
  49. try:
  50. prompt = self.tokenizer.apply_chat_template(
  51. conversation=request.messages,
  52. tokenize=False,
  53. add_generation_prompt=request.add_generation_prompt)
  54. except Exception as e:
  55. logger.error(
  56. f"Error in applying chat template from request: {str(e)}")
  57. return self.create_error_response(str(e))
  58. request_id = f"cmpl-{random_uuid()}"
  59. try:
  60. # Tokenize/detokenize depending on prompt format (string/token list)
  61. prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
  62. request, prompt=prompt)
  63. sampling_params = request.to_sampling_params(
  64. self.tokenizer.vocab_size)
  65. lora_request = self._maybe_get_lora(request)
  66. guided_decode_logits_processor = (
  67. await get_guided_decoding_logits_processor(
  68. request.guided_decoding_backend, request, await
  69. self.engine.get_tokenizer()))
  70. if guided_decode_logits_processor:
  71. sampling_params.logits_processors.append(
  72. guided_decode_logits_processor)
  73. except ValueError as e:
  74. return self.create_error_response(str(e))
  75. result_generator = self.engine.generate(prompt_text, sampling_params,
  76. request_id, prompt_ids,
  77. lora_request)
  78. # Streaming response
  79. if request.stream:
  80. return self.chat_completion_stream_generator(
  81. request, result_generator, request_id)
  82. else:
  83. try:
  84. return await self.chat_completion_full_generator(
  85. request, raw_request, result_generator, request_id)
  86. except ValueError as e:
  87. # TODO: Use an aphrodite-specific Validation Error
  88. return self.create_error_response(str(e))
  89. def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
  90. if request.add_generation_prompt:
  91. return self.response_role
  92. else:
  93. return request.messages[-1]["role"]
  94. async def chat_completion_stream_generator(
  95. self, request: ChatCompletionRequest,
  96. result_generator: AsyncIterator[RequestOutput], request_id: str
  97. ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
  98. model_name = self.served_model_names[0]
  99. created_time = int(time.time())
  100. chunk_object_type = "chat.completion.chunk"
  101. first_iteration = True
  102. # Send response for each token for each request.n (index)
  103. previous_texts = [""] * request.n
  104. previous_num_tokens = [0] * request.n
  105. finish_reason_sent = [False] * request.n
  106. try:
  107. async for res in result_generator:
  108. res: RequestOutput
  109. # We need to do it here, because if there are exceptions in
  110. # the result_generator, it needs to be sent as the FIRST
  111. # response (by the try...catch).
  112. if first_iteration:
  113. # Send first response for each request.n (index) with
  114. # the role
  115. role = self.get_chat_request_role(request)
  116. for i in range(request.n):
  117. choice_data = ChatCompletionResponseStreamChoice(
  118. index=i,
  119. delta=DeltaMessage(role=role),
  120. logprobs=None,
  121. finish_reason=None)
  122. chunk = ChatCompletionStreamResponse(
  123. id=request_id,
  124. object=chunk_object_type,
  125. created=created_time,
  126. choices=[choice_data],
  127. model=model_name)
  128. data = chunk.model_dump_json(exclude_unset=True)
  129. yield f"data: {data}\n\n"
  130. # Send response to echo the input portion of the
  131. # last message
  132. if request.echo:
  133. last_msg_content = ""
  134. if request.messages and isinstance(
  135. request.messages,
  136. list) and request.messages[-1].get(
  137. "content") and request.messages[-1].get(
  138. "role") == role:
  139. last_msg_content = request.messages[-1]["content"]
  140. if last_msg_content:
  141. for i in range(request.n):
  142. choice_data = (
  143. ChatCompletionResponseStreamChoice(
  144. index=i,
  145. delta=DeltaMessage(
  146. content=last_msg_content),
  147. finish_reason=None))
  148. chunk = ChatCompletionStreamResponse(
  149. id=request_id,
  150. object=chunk_object_type,
  151. created=created_time,
  152. choices=[choice_data],
  153. logprobs=None,
  154. model=model_name)
  155. data = chunk.model_dump_json(
  156. exclude_unset=True)
  157. yield f"data: {data}\n\n"
  158. first_iteration = False
  159. for output in res.outputs:
  160. i = output.index
  161. if finish_reason_sent[i]:
  162. continue
  163. delta_token_ids = output.token_ids[previous_num_tokens[i]:]
  164. top_logprobs = output.logprobs[
  165. previous_num_tokens[i]:] if output.logprobs else None
  166. if request.logprobs:
  167. logprobs = self._create_logprobs(
  168. token_ids=delta_token_ids,
  169. top_logprobs=top_logprobs,
  170. num_output_top_logprobs=request.logprobs,
  171. initial_text_offset=len(previous_texts[i]),
  172. )
  173. else:
  174. logprobs = None
  175. delta_text = output.text[len(previous_texts[i]):]
  176. previous_texts[i] = output.text
  177. previous_num_tokens[i] = len(output.token_ids)
  178. if output.finish_reason is None:
  179. # Send token-by-token response for each request.n
  180. choice_data = ChatCompletionResponseStreamChoice(
  181. index=i,
  182. delta=DeltaMessage(content=delta_text),
  183. logprobs=logprobs,
  184. finish_reason=None)
  185. chunk = ChatCompletionStreamResponse(
  186. id=request_id,
  187. object=chunk_object_type,
  188. created=created_time,
  189. choices=[choice_data],
  190. model=model_name)
  191. data = chunk.model_dump_json(exclude_unset=True)
  192. yield f"data: {data}\n\n"
  193. else:
  194. # Send the finish response for each request.n only once
  195. prompt_tokens = len(res.prompt_token_ids)
  196. final_usage = UsageInfo(
  197. prompt_tokens=prompt_tokens,
  198. completion_tokens=previous_num_tokens[i],
  199. total_tokens=prompt_tokens +
  200. previous_num_tokens[i],
  201. )
  202. choice_data = ChatCompletionResponseStreamChoice(
  203. index=i,
  204. delta=DeltaMessage(content=delta_text),
  205. logprobs=logprobs,
  206. finish_reason=output.finish_reason,
  207. stop_reason=output.stop_reason)
  208. chunk = ChatCompletionStreamResponse(
  209. id=request_id,
  210. object=chunk_object_type,
  211. created=created_time,
  212. choices=[choice_data],
  213. model=model_name)
  214. if final_usage is not None:
  215. chunk.usage = final_usage
  216. data = chunk.model_dump_json(exclude_unset=True,
  217. exclude_none=True)
  218. yield f"data: {data}\n\n"
  219. finish_reason_sent[i] = True
  220. except ValueError as e:
  221. # TODO: Use an aphrodite-specific Validation Error
  222. data = self.create_streaming_error_response(str(e))
  223. yield f"data: {data}\n\n"
  224. # Send the final done message after all response.n are finished
  225. yield "data: [DONE]\n\n"
  226. async def chat_completion_full_generator(
  227. self, request: ChatCompletionRequest, raw_request: Request,
  228. result_generator: AsyncIterator[RequestOutput],
  229. request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
  230. model_name = self.served_model_names[0]
  231. created_time = int(time.time())
  232. final_res: RequestOutput = None
  233. async for res in result_generator:
  234. if await raw_request.is_disconnected():
  235. # Abort the request if the client disconnects.
  236. await self.engine.abort(request_id)
  237. return self.create_error_response("Client disconnected")
  238. final_res = res
  239. assert final_res is not None
  240. choices = []
  241. role = self.get_chat_request_role(request)
  242. for output in final_res.outputs:
  243. token_ids = output.token_ids
  244. top_logprobs = output.logprobs
  245. if request.logprobs:
  246. logprobs = self._create_logprobs(
  247. token_ids=token_ids,
  248. top_logprobs=top_logprobs,
  249. num_output_top_logprobs=request.logprobs,
  250. )
  251. else:
  252. logprobs = None
  253. choice_data = ChatCompletionResponseChoice(
  254. index=output.index,
  255. message=ChatMessage(role=role, content=output.text),
  256. logprobs=logprobs,
  257. finish_reason=output.finish_reason,
  258. stop_reason=output.stop_reason,
  259. )
  260. choices.append(choice_data)
  261. if request.echo:
  262. last_msg_content = ""
  263. if request.messages and isinstance(
  264. request.messages, list) and request.messages[-1].get(
  265. "content") and request.messages[-1].get(
  266. "role") == role:
  267. last_msg_content = request.messages[-1]["content"]
  268. for choice in choices:
  269. full_message = last_msg_content + choice.message.content
  270. choice.message.content = full_message
  271. num_prompt_tokens = len(final_res.prompt_token_ids)
  272. num_generated_tokens = sum(
  273. len(output.token_ids) for output in final_res.outputs)
  274. usage = UsageInfo(
  275. prompt_tokens=num_prompt_tokens,
  276. completion_tokens=num_generated_tokens,
  277. total_tokens=num_prompt_tokens + num_generated_tokens,
  278. )
  279. response = ChatCompletionResponse(
  280. id=request_id,
  281. created=created_time,
  282. model=model_name,
  283. choices=choices,
  284. usage=usage,
  285. )
  286. return response
  287. def _load_chat_template(self, chat_template):
  288. if chat_template is not None:
  289. try:
  290. with open(chat_template, "r") as f:
  291. self.tokenizer.chat_template = f.read()
  292. except OSError:
  293. # If opening a file fails, set chat template to be args to
  294. # ensure we decode so our escape are interpreted correctly
  295. self.tokenizer.chat_template = codecs.decode(
  296. chat_template, "unicode_escape")
  297. logger.info("Using the supplied chat template.")
  298. elif self.tokenizer.chat_template is not None:
  299. logger.info("Using the default chat template")
  300. else:
  301. logger.warning(
  302. "No chat template provided. Chat API will not work.")