serving_chat.py

import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request
from loguru import logger
from transformers import PreTrainedTokenizer

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import iterate_with_cancellation, random_uuid
from aphrodite.endpoints.chat_utils import (ConversationMessage,
                                            apply_chat_template,
                                            load_chat_template,
                                            parse_chat_messages)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    FunctionCall, ToolCall, UsageInfo)
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       OpenAIServing,
                                                       PromptAdapterPath,
                                                       TextTokensPrompt)
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.inputs import PromptInputs
from aphrodite.multimodal import MultiModalDataDict

class OpenAIServingChat(OpenAIServing):

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        response_role: str,
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role

        # If this is None we use the tokenizer's default chat template
        self.chat_template = load_chat_template(chat_template)
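    # Illustrative construction sketch (not part of the original file; the
    # engine client, model config, and argument values are placeholders):
    #
    #     serving_chat = OpenAIServingChat(
    #         async_engine_client=engine_client,
    #         model_config=model_config,
    #         served_model_names=["my-model"],
    #         response_role="assistant",
    #         lora_modules=None,
    #         prompt_adapters=None,
    #         request_logger=None,
    #         chat_template=None,
    #     )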
    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        if request.prompt_logprobs is not None:
            if request.stream and request.prompt_logprobs > 0:
                return self.create_error_response(
                    "Prompt_logprobs are not available when stream is enabled")

            if request.prompt_logprobs < 0:
                return self.create_error_response(
                    f"Prompt_logprobs set to invalid "
                    f"negative value: {request.prompt_logprobs}")

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            conversation, mm_futures = parse_chat_messages(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt = apply_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
        except Exception as e:
            logger.error(f"Error in applying chat template from request: {e}")
            return self.create_error_response(str(e))

        mm_data: Optional[MultiModalDataDict] = None
        try:
            if len(mm_futures):
                # since we support only single mm data currently
                assert len(
                    mm_futures
                ) == 1, "Multiple 'image_url' input is currently not supported."
                mm_data = await mm_futures[0]
        except Exception as e:
            logger.error(f"Error in loading multi-modal data: {e}")
            return self.create_error_response(str(e))

        request_id = f"chat-{random_uuid()}"
        try:
            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            if isinstance(prompt, str):
                prompt_inputs = self._tokenize_prompt_input(
                    request,
                    tokenizer,
                    prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                assert isinstance(prompt, list) and isinstance(
                    prompt[0], int
                ), "Prompt has to be either a string or a list of token ids"
                prompt_inputs = TextTokensPrompt(
                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

            assert prompt_inputs is not None

            sampling_params = request.to_sampling_params(
                tokenizer,
                guided_decode_logits_processor,
                default_max_tokens=self.max_model_len -
                len(prompt_inputs["prompt_token_ids"]))

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs: PromptInputs = {
                "prompt_token_ids": prompt_inputs["prompt_token_ids"],
            }
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            result_generator = self.async_engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))
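    # Usage sketch (illustrative; the actual FastAPI route wiring lives in the
    # API server module, and `serving_chat` is a placeholder name):
    #
    #     generator = await serving_chat.create_chat_completion(
    #         request, raw_request)
    #     if isinstance(generator, ErrorResponse):
    #         return JSONResponse(content=generator.model_dump(),
    #                             status_code=generator.code)
    #     if request.stream:
    #         return StreamingResponse(content=generator,
    #                                  media_type="text/event-stream")
    #     return JSONResponse(content=generator.model_dump())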
    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        else:
            return request.messages[-1]["role"]
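    # For example: with add_generation_prompt=True the configured
    # response_role (typically "assistant") is used; otherwise the role of the
    # last message in the request is echoed back.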
    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: PreTrainedTokenizer,
    ) -> AsyncGenerator[str, None]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_texts = [""] * num_choices
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices

        try:
            async for res in result_generator:
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if (request.stream_options.continuous_usage_stats):
                                prompt_tokens = len(res.prompt_token_ids)
                                usage = UsageInfo(prompt_tokens=prompt_tokens,
                                                  completion_tokens=0,
                                                  total_tokens=prompt_tokens)
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if (request.stream_options and
                                        request.stream_options.include_usage):
                                    if (request.stream_options.
                                            continuous_usage_stats):
                                        prompt_tokens = len(
                                            res.prompt_token_ids)
                                        usage = UsageInfo(
                                            prompt_tokens=prompt_tokens,
                                            completion_tokens=0,
                                            total_tokens=prompt_tokens)
                                        chunk.usage = usage
                                    else:
                                        chunk.usage = None

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                    out_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    if request.tool_choice and type(
                            request.tool_choice
                    ) is ChatCompletionNamedToolChoiceParam:
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if (request.stream_options.continuous_usage_stats):
                                prompt_tokens = len(res.prompt_token_ids)
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                    else:
                        # Send the finish response for each request.n only once
                        prompt_tokens = len(res.prompt_token_ids)
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if (request.stream_options.continuous_usage_stats):
                                prompt_tokens = len(res.prompt_token_ids)
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            if (request.stream_options
                    and request.stream_options.include_usage):
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=previous_num_tokens[i],
                    total_tokens=prompt_tokens + previous_num_tokens[i],
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"
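    # Wire format emitted by the generator above, as consumed by SSE clients
    # (field values are illustrative):
    #
    #     data: {"id": "chat-...", "object": "chat.completion.chunk", ...}
    #     data: {"id": "chat-...", "object": "chat.completion.chunk", ...}
    #     data: [DONE]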
    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: PreTrainedTokenizer,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    tokenizer=tokenizer,
                    num_output_top_logprobs=request.top_logprobs,
                )
            else:
                logprobs = None

            if request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
                last_msg_content = conversation[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response
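    # The returned ChatCompletionResponse serializes to the familiar OpenAI
    # shape, roughly (illustrative values):
    #
    #     {"id": "chat-...", "object": "chat.completion", "created": ...,
    #      "model": "...", "choices": [{"index": 0, "message": {...},
    #      "finish_reason": "stop"}], "usage": {...}}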
    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
        return [
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]
    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: PreTrainedTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""

        logprobs_content = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace"))))
            else:
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_top_logprobs[token_id], token_id, tokenizer,
                            self.return_tokens_as_token_ids),
                        logprob=max(step_top_logprobs[token_id].logprob,
                                    -9999.0),
                        bytes=list(
                            step_top_logprobs[token_id].decoded_token.encode(
                                "utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs, num_output_top_logprobs,
                            tokenizer)))

        return ChatCompletionLogProbs(content=logprobs_content)
  507. return ChatCompletionLogProbs(content=logprobs_content)