serving_chat.py

import codecs
import time
from typing import AsyncGenerator, AsyncIterator, List, Optional, Union

from fastapi import Request
from loguru import logger

from aphrodite.common.outputs import RequestOutput
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    UsageInfo)
from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.modeling.outlines_decoding import (
    get_guided_decoding_logits_processor)


class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 engine: AsyncAphrodite,
                 served_model: str,
                 response_role: str,
                 lora_modules: Optional[List[LoRA]] = None,
                 chat_template=None):
        super().__init__(engine=engine,
                         served_model=served_model,
                         lora_modules=lora_modules)
        self.response_role = response_role
        self._load_chat_template(chat_template)

    async def create_chat_completion(
        self, request: ChatCompletionRequest, raw_request: Request
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        try:
            prompt = self.tokenizer.apply_chat_template(
                conversation=request.messages,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt)
        except Exception as e:
            logger.error(
                f"Error in applying chat template from request: {str(e)}")
            return self.create_error_response(str(e))

        request_id = f"cmpl-{random_uuid()}"
        try:
            token_ids = self._validate_prompt_and_tokenize(request,
                                                           prompt=prompt)
            sampling_params = request.to_sampling_params(
                self.tokenizer.vocab_size)
            lora_request = self._maybe_get_lora(request)
            guided_decode_logits_processor = (
                await get_guided_decoding_logits_processor(
                    request, self.engine.get_tokenizer()))
            if guided_decode_logits_processor:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logits_processor)
        except ValueError as e:
            return self.create_error_response(str(e))

        result_generator = self.engine.generate(prompt, sampling_params,
                                                request_id, token_ids,
                                                lora_request)
        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id)
        else:
            try:
                return await self.chat_completion_full_generator(
                    request, raw_request, result_generator, request_id)
            except ValueError as e:
                # TODO: Use an aphrodite-specific Validation Error
                return self.create_error_response(str(e))
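
    # A minimal client-side sketch of calling this handler. The route
    # ("/v1/chat/completions"), host, port, and model name below are
    # assumptions for illustration; they depend on how the API server
    # mounts this handler and on your deployment.
    #
    #   import requests
    #
    #   resp = requests.post(
    #       "http://localhost:2242/v1/chat/completions",
    #       json={
    #           "model": "my-model",
    #           "messages": [{"role": "user", "content": "Hello!"}],
    #           "stream": False,
    #       })
    #   print(resp.json()["choices"][0]["message"]["content"])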

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        else:
            return request.messages[-1]["role"]

    async def chat_completion_stream_generator(
            self, request: ChatCompletionRequest,
            result_generator: AsyncIterator[RequestOutput], request_id: str
    ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
        """Stream the chat completion as OpenAI-style server-sent events:
        one `data: <chunk json>` event per delta, closed by `data: [DONE]`."""
        model_name = request.model
        # Unix timestamp in seconds, as the `created` field expects.
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        finish_reason_sent = [False] * request.n
        try:
            async for res in result_generator:
                res: RequestOutput
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(request.n):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the last
                    # message
                    if request.echo:
                        last_msg_content = ""
                        if request.messages and isinstance(
                                request.messages,
                                list) and request.messages[-1].get(
                                    "content") and request.messages[-1].get(
                                        "role") == role:
                            last_msg_content = request.messages[-1]["content"]
                        if last_msg_content:
                            for i in range(request.n):
                                choice_data = ChatCompletionResponseStreamChoice(  # noqa
                                    index=i,
                                    delta=DeltaMessage(
                                        content=last_msg_content),
                                    finish_reason=None)
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    logprobs=None,
                                    model=model_name)
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    delta_token_ids = output.token_ids[
                        previous_num_tokens[i]:]
                    top_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs:
                        logprobs = self._create_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=top_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(content=delta_text),
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                    else:
                        # Send the finish response for each request.n only
                        # once
                        prompt_tokens = len(res.prompt_token_ids)
                        final_usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=previous_num_tokens[i],
                            total_tokens=prompt_tokens +
                            previous_num_tokens[i],
                        )
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(content=delta_text),
                            logprobs=logprobs,
                            finish_reason=output.finish_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if final_usage is not None:
                            chunk.usage = final_usage
                        data = chunk.model_dump_json(exclude_unset=True,
                                                     exclude_none=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"
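
    # For reference, each event yielded above is one JSON-serialized
    # ChatCompletionStreamResponse. An illustrative (not verbatim) chunk on
    # the wire looks like:
    #
    #   data: {"id": "cmpl-...", "object": "chat.completion.chunk",
    #          "created": 1700000000, "model": "my-model",
    #          "choices": [{"index": 0, "delta": {"content": "Hello"},
    #                       "finish_reason": null}]}
    #
    # followed by a blank line, and the stream is terminated by
    # `data: [DONE]`.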

    async def chat_completion_full_generator(
            self, request: ChatCompletionRequest, raw_request: Request,
            result_generator: AsyncIterator[RequestOutput],
            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
        """Wait for the full generation to finish and return a single
        ChatCompletionResponse, aborting if the client disconnects."""
        model_name = request.model
        # Unix timestamp in seconds, as the `created` field expects.
        created_time = int(time.time())
        final_res: RequestOutput = None

        async for res in result_generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return self.create_error_response("Client disconnected")
            final_res = res
        assert final_res is not None

        choices = []
        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            top_logprobs = output.logprobs

            if request.logprobs:
                logprobs = self._create_logprobs(
                    token_ids=token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                )
            else:
                logprobs = None

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=ChatMessage(role=role, content=output.text),
                logprobs=logprobs,
                finish_reason=output.finish_reason,
            )
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if request.messages and isinstance(
                    request.messages, list) and request.messages[-1].get(
                        "content") and request.messages[-1].get(
                            "role") == role:
                last_msg_content = request.messages[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

        return response

    def _load_chat_template(self, chat_template):
        """Load the chat template from a file path, or fall back to treating
        the argument as the template string itself."""
        if chat_template is not None:
            try:
                with open(chat_template, "r") as f:
                    self.tokenizer.chat_template = f.read()
            except OSError:
                # If opening the file fails, treat the argument itself as the
                # template and decode it so escape sequences are interpreted
                # correctly.
                self.tokenizer.chat_template = codecs.decode(
                    chat_template, "unicode_escape")

            logger.info("Using the supplied chat template.")
        elif self.tokenizer.chat_template is not None:
            logger.info("Using the default chat template.")
        else:
            logger.warning(
                "No chat template provided. Chat API will not work.")