# serving_chat.py

import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request
from loguru import logger
from transformers import PreTrainedTokenizer

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import iterate_with_cancellation, random_uuid
from aphrodite.endpoints.chat_utils import (ConversationMessage,
                                            apply_chat_template,
                                            load_chat_template,
                                            parse_chat_messages)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    FunctionCall, ToolCall, UsageInfo)
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       OpenAIServing,
                                                       PromptAdapterPath)
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.inputs import PromptInputs
from aphrodite.multimodal import MultiModalDataDict


class OpenAIServingChat(OpenAIServing):

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        response_role: str,
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role

        # If this is None we use the tokenizer's default chat template
        self.chat_template = load_chat_template(chat_template)

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            conversation, mm_futures = parse_chat_messages(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt = apply_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
        except Exception as e:
            logger.error(f"Error in applying chat template from request: {e}")
            return self.create_error_response(str(e))

        mm_data: Optional[MultiModalDataDict] = None
        try:
            if len(mm_futures):
                # since we support only single mm data currently
                assert len(
                    mm_futures
                ) == 1, "Multiple 'image_url' input is currently not supported."
                mm_data = await mm_futures[0]
        except Exception as e:
            logger.error(f"Error in loading multi-modal data: {e}")
            return self.create_error_response(str(e))

        request_id = f"chat-{random_uuid()}"
        try:
            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            prompt_inputs = self._tokenize_prompt_input(
                request,
                tokenizer,
                prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )

            sampling_params = request.to_sampling_params(
                tokenizer,
                guided_decode_logits_processor,
                default_max_tokens=self.max_model_len -
                len(prompt_inputs["prompt_token_ids"]))

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs: PromptInputs = {
                "prompt_token_ids": prompt_inputs["prompt_token_ids"],
            }
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            result_generator = self.async_engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))
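
    # When add_generation_prompt is False, the response continues the last
    # conversation message, so that message's role is reused instead of the
    # configured response_role.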
    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        else:
            return request.messages[-1]["role"]
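
    # Streams an OpenAI-compatible SSE response: each yielded item is a
    # "data: {json}\n\n" line whose payload is a "chat.completion.chunk"
    # object, e.g. (illustrative, fields abridged):
    #   data: {"id": "chat-...", "object": "chat.completion.chunk",
    #          "choices": [{"index": 0, "delta": {"content": "Hi"}}]}
    # The first chunk for each choice carries only the role, and the stream
    # always terminates with "data: [DONE]".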
    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: PreTrainedTokenizer,
    ) -> AsyncGenerator[str, None]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_texts = [""] * num_choices
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices

        try:
            async for res in result_generator:
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if (request.stream_options.continuous_usage_stats):
                                prompt_tokens = len(res.prompt_token_ids)
                                usage = UsageInfo(prompt_tokens=prompt_tokens,
                                                  completion_tokens=0,
                                                  total_tokens=prompt_tokens)
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if (request.stream_options and
                                        request.stream_options.include_usage):
                                    if (request.stream_options.
                                            continuous_usage_stats):
                                        prompt_tokens = len(
                                            res.prompt_token_ids)
                                        usage = UsageInfo(
                                            prompt_tokens=prompt_tokens,
                                            completion_tokens=0,
                                            total_tokens=prompt_tokens)
                                        chunk.usage = usage
                                    else:
                                        chunk.usage = None

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                    out_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs and request.top_logprobs is not None:
                        assert out_logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=out_logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)

                    if request.tool_choice and type(
                            request.tool_choice
                    ) is ChatCompletionNamedToolChoiceParam:
                        delta_message = DeltaMessage(tool_calls=[
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if (request.stream_options.continuous_usage_stats):
                                prompt_tokens = len(res.prompt_token_ids)
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                    else:
                        # Send the finish response for each request.n only once
                        prompt_tokens = len(res.prompt_token_ids)
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if (request.stream_options.continuous_usage_stats):
                                prompt_tokens = len(res.prompt_token_ids)
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            if (request.stream_options
                    and request.stream_options.include_usage):
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=previous_num_tokens[i],
                    total_tokens=prompt_tokens + previous_num_tokens[i],
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"
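
    # Non-streaming path: drains the result generator and assembles a single
    # ChatCompletionResponse once generation has finished (or the client
    # disconnects).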
    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: PreTrainedTokenizer,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    tokenizer=tokenizer,
                    num_output_top_logprobs=request.top_logprobs,
                )
            else:
                logprobs = None

            if request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
                last_msg_content = conversation[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

        return response
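
    # The helpers below convert the engine's per-step Logprob dicts into the
    # OpenAI logprobs schema: one ChatCompletionLogProbsContent entry per
    # sampled token, each optionally carrying a top_logprobs list of
    # alternative tokens.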
    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
        return [
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: PreTrainedTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""

        logprobs_content = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace"))))
            else:
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_top_logprobs[token_id], token_id, tokenizer,
                            self.return_tokens_as_token_ids),
                        logprob=max(step_top_logprobs[token_id].logprob,
                                    -9999.0),
                        bytes=list(
                            step_top_logprobs[token_id].decoded_token.encode(
                                "utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs, num_output_top_logprobs,
                            tokenizer)))

        return ChatCompletionLogProbs(content=logprobs_content)
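

# Usage sketch (illustrative only, not part of this module): one way an API
# server could wire this class into a FastAPI route. `engine_client`,
# `model_config`, and `app` are assumed to be constructed elsewhere, and
# JSONResponse / StreamingResponse come from fastapi.responses; the exact
# wiring in the server may differ.
#
#     serving_chat = OpenAIServingChat(
#         engine_client,
#         model_config,
#         served_model_names=["my-model"],
#         response_role="assistant",
#         lora_modules=None,
#         prompt_adapters=None,
#         request_logger=None,
#         chat_template=None,
#     )
#
#     @app.post("/v1/chat/completions")
#     async def create_chat_completion(request: ChatCompletionRequest,
#                                      raw_request: Request):
#         generator = await serving_chat.create_chat_completion(
#             request, raw_request)
#         if isinstance(generator, ErrorResponse):
#             return JSONResponse(content=generator.model_dump(),
#                                 status_code=generator.code)
#         if request.stream:
#             return StreamingResponse(content=generator,
#                                      media_type="text/event-stream")
#         return JSONResponse(content=generator.model_dump())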