serving_chat.py

import codecs
import time
from typing import AsyncGenerator, AsyncIterator, List, Optional, Union

from fastapi import Request
from loguru import logger

from aphrodite.common.outputs import RequestOutput
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    ChatCompletionResponse, ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
    ChatMessage, DeltaMessage, ErrorResponse, FunctionCall, ToolCall,
    UsageInfo)
from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.modeling.guided_decoding import (
    get_guided_decoding_logits_processor)


class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 engine: AsyncAphrodite,
                 served_model_names: List[str],
                 response_role: str,
                 lora_modules: Optional[List[LoRA]] = None,
                 chat_template=None):
        super().__init__(engine=engine,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules)
        self.response_role = response_role
        self._load_chat_template(chat_template)

    async def create_chat_completion(
        self, request: ChatCompletionRequest, raw_request: Request
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Deal with list in messages.content
        # Just replace the content list with the very first text message
        for message in request.messages:
            if message["role"] == "user" and isinstance(
                    message["content"], list):
                message["content"] = next((content["text"]
                                           for content in message["content"]
                                           if content["type"] == "text"), "")

        try:
            prompt = self.tokenizer.apply_chat_template(
                conversation=request.messages,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt)
        except Exception as e:
            logger.error(
                f"Error in applying chat template from request: {str(e)}")
            return self.create_error_response(str(e))

        request_id = f"cmpl-{random_uuid()}"
        try:
            # Tokenize/detokenize depending on prompt format (string/token list)
            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
                request, prompt=prompt)
            sampling_params = request.to_sampling_params(
                self.tokenizer.vocab_size)
            lora_request = self._maybe_get_lora(request)
            guided_decode_logits_processor = (
                await get_guided_decoding_logits_processor(
                    request.guided_decoding_backend, request, await
                    self.engine.get_tokenizer()))
            if guided_decode_logits_processor:
                sampling_params.logits_processors.append(
                    guided_decode_logits_processor)
        except ValueError as e:
            return self.create_error_response(str(e))

        result_generator = self.engine.generate(
            {
                "prompt": prompt_text,
                "prompt_token_ids": prompt_ids
            },
            sampling_params,
            request_id,
            lora_request,
        )
        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id)
        else:
            try:
                return await self.chat_completion_full_generator(
                    request, raw_request, result_generator, request_id)
            except ValueError as e:
                # TODO: Use an aphrodite-specific Validation Error
                return self.create_error_response(str(e))

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        else:
            return request.messages[-1]["role"]
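
    # The streaming generator below emits Server-Sent Events: for each choice
    # a first chunk carrying only the assistant role, an optional echo of the
    # last input message, one chunk per newly generated text delta, a final
    # chunk with the finish reason, an optional usage chunk when
    # stream_options.include_usage is set, and a closing "data: [DONE]" line.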

    async def chat_completion_stream_generator(
            self, request: ChatCompletionRequest,
            result_generator: AsyncIterator[RequestOutput], request_id: str
    ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        finish_reason_sent = [False] * request.n
        try:
            async for res in result_generator:
                res: RequestOutput
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(request.n):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content = ""
                        if request.messages and isinstance(
                                request.messages,
                                list) and request.messages[-1].get(
                                    "content") and request.messages[-1].get(
                                        "role") == role:
                            last_msg_content = request.messages[-1]["content"]

                        if last_msg_content:
                            for i in range(request.n):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    logprobs=None,
                                    model=model_name)
                                if (request.stream_options and
                                        request.stream_options.include_usage):
                                    chunk.usage = None
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                    top_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs:
                        logprobs = self._create_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=top_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    if request.tool_choice and type(
                            request.tool_choice
                    ) is ChatCompletionNamedToolChoiceParam:
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                    else:
                        # Send the finish response for each request.n only once
                        prompt_tokens = len(res.prompt_token_ids)
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            chunk.usage = None
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            if (request.stream_options
                    and request.stream_options.include_usage):
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=previous_num_tokens[i],
                    total_tokens=prompt_tokens + previous_num_tokens[i],
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"
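
    # The non-streaming path below drains the result generator to the final
    # RequestOutput (aborting the engine request if the client disconnects)
    # and assembles a single ChatCompletionResponse with token usage counts.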

    async def chat_completion_full_generator(
            self, request: ChatCompletionRequest, raw_request: Request,
            result_generator: AsyncIterator[RequestOutput],
            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: RequestOutput = None

        async for res in result_generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return self.create_error_response("Client disconnected")
            final_res = res
        assert final_res is not None

        choices = []
        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            top_logprobs = output.logprobs

            if request.logprobs:
                logprobs = self._create_logprobs(
                    token_ids=token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                )
            else:
                logprobs = None

            if request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason,
            )
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if request.messages and isinstance(
                    request.messages, list) and request.messages[-1].get(
                        "content") and request.messages[-1].get(
                            "role") == role:
                last_msg_content = request.messages[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

        return response
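
    # Chat template resolution: an explicitly supplied template is read from
    # a file path if possible, otherwise treated as a literal template string
    # (with escape sequences decoded); failing that, the tokenizer's bundled
    # template is used, and without either the Chat API cannot build prompts.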

    def _load_chat_template(self, chat_template):
        if chat_template is not None:
            try:
                with open(chat_template, "r") as f:
                    self.tokenizer.chat_template = f.read()
            except OSError:
                # If opening a file fails, treat chat_template as a literal
                # template string and decode it so escape sequences are
                # interpreted correctly.
                self.tokenizer.chat_template = codecs.decode(
                    chat_template, "unicode_escape")

            logger.info("Using the supplied chat template.")
        elif self.tokenizer.chat_template is not None:
            logger.info("Using the default chat template.")
        else:
            logger.warning(
                "No chat template provided. Chat API will not work.")