serving_chat.py
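"""OpenAI-compatible Chat Completions serving layer for Aphrodite.

Defines OpenAIServingChat, which validates ChatCompletionRequests, applies the
chat template, submits the prompt to the engine, and renders the result as
either a full ChatCompletionResponse or a Server-Sent Events stream, with
optional tool-call parsing.
"""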

import asyncio
import json
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
                    Optional)
from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request
from loguru import logger

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import CompletionOutput, RequestOutput
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import iterate_with_cancellation, random_uuid
from aphrodite.endpoints.chat_utils import (ConversationMessage,
                                            apply_hf_chat_template,
                                            apply_mistral_chat_template,
                                            load_chat_template,
                                            parse_chat_messages_futures)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
from aphrodite.endpoints.openai.serving_engine import (BaseModelPath,
                                                       LoRAModulePath,
                                                       OpenAIServing,
                                                       PromptAdapterPath,
                                                       TextTokensPrompt)
from aphrodite.endpoints.openai.tool_parsers import (Hermes2ProToolParser,
                                                     MistralToolParser,
                                                     ToolParser)
from aphrodite.engine.protocol import EngineClient
from aphrodite.inputs import TokensPrompt
from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
                                                    MistralTokenizer)


class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 engine_client: EngineClient,
                 model_config: ModelConfig,
                 base_model_paths: List[BaseModelPath],
                 response_role: str,
                 *,
                 lora_modules: Optional[List[LoRAModulePath]],
                 prompt_adapters: Optional[List[PromptAdapterPath]],
                 request_logger: Optional[RequestLogger],
                 chat_template: Optional[str],
                 return_tokens_as_token_ids: bool = False,
                 enable_auto_tools: bool = False,
                 tool_parser: Optional[str] = None):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.use_tool_use_model_template = False
        self.chat_template = load_chat_template(chat_template)

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        if self.enable_auto_tools:
            logger.info(
                "\"auto\" tool choice has been enabled; note that while the "
                "parallel_tool_calls client option is accepted for "
                "compatibility reasons, it will be ignored.")

        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if self.enable_auto_tools:
            if tool_parser == "mistral":
                self.tool_parser = MistralToolParser
            elif tool_parser == "hermes":
                self.tool_parser = Hermes2ProToolParser
            else:
                raise TypeError("Error: --enable-auto-tool-choice requires "
                                "--tool-call-parser")

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error(f"Error with model {error_check_ret}")
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt: Union[str, List[int]]
            is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
            if is_mistral_tokenizer:
                prompt = apply_mistral_chat_template(
                    tokenizer,
                    messages=request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
            else:
                prompt = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
        except Exception as e:
            logger.error(f"Error in applying chat template from request: {e}")
            return self.create_error_response(str(e))

        try:
            mm_data = await mm_data_future
        except Exception as e:
            logger.error(f"Error in loading multi-modal data: {e}")
            return self.create_error_response(str(e))

        # validation for OpenAI tools
        # tool_choice = "required" is not supported
        if request.tool_choice == "required":
            return self.create_error_response(
                "tool_choice = \"required\" is not supported!")

        if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
                self.enable_auto_tools and self.tool_parser is not None):
            # for hf tokenizers, "auto" tool choice requires
            # --enable-auto-tool-choice and --tool-call-parser
            return self.create_error_response(
                "\"auto\" tool choice requires "
                "--enable-auto-tool-choice and --tool-call-parser to be set")

        request_id = f"chat-{random_uuid()}"
        try:
            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            if isinstance(prompt, str):
                prompt_inputs = self._tokenize_prompt_input(
                    request,
                    tokenizer,
                    prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                assert isinstance(prompt, list) and isinstance(
                    prompt[0], int
                ), "Prompt has to be either a string or a list of token ids"
                prompt_inputs = TextTokensPrompt(
                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

            assert prompt_inputs is not None

            sampling_params = request.to_sampling_params(
                tokenizer,
                guided_decode_logits_processor,
                default_max_tokens=self.max_model_len -
                len(prompt_inputs["prompt_token_ids"]))

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs = TokensPrompt(
                prompt_token_ids=prompt_inputs["prompt_token_ids"])
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            result_generator = self.engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        return request.messages[-1]["role"]

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
    ) -> AsyncGenerator[str, None]:
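        # Each item yielded below is one Server-Sent Events frame of the form
        # "data: {json}\n\n", where the JSON payload is a
        # chat.completion.chunk object; the stream terminates with a literal
        # "data: [DONE]\n\n" frame.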
        model_name = self.base_model_paths[0].name
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0

        tool_parser: Optional[ToolParser] = self.tool_parser(
            tokenizer) if self.tool_parser else None
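        # NOTE: the parser is constructed per stream because it accumulates
        # state that is consulted when the final chunk is emitted (e.g.
        # prev_tool_call_arr and streamed_args_for_tool below).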
        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        all_previous_token_ids: Optional[List[List[int]]]
        if tool_choice_auto:
            # These are only required in "auto" tool choice case
            previous_texts = [""] * num_choices
            all_previous_token_ids = [[]] * num_choices
        else:
            previous_texts, all_previous_token_ids = None, None

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # if usage should be included
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            # if continuous usage stats are requested, add it
                            if request.stream_options.continuous_usage_stats:
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=0,
                                    total_tokens=num_prompt_tokens)
                                chunk.usage = usage
                            # otherwise don't
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: str = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if (request.stream_options and
                                        request.stream_options.include_usage):
                                    if (request.stream_options.
                                            continuous_usage_stats):
                                        usage = UsageInfo(
                                            prompt_tokens=num_prompt_tokens,
                                            completion_tokens=0,
                                            total_tokens=num_prompt_tokens)
                                        chunk.usage = usage
                                    else:
                                        chunk.usage = None

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False
                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text
                    delta_message: Optional[DeltaMessage]

                    # handle streaming deltas for tools with named tool_choice
                    if tool_choice_function_name:
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=tool_choice_function_name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif tool_choice_auto:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        assert tool_parser is not None
                        # TODO optimize manipulation of these lists
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_num_tokens[i] += len(output.token_ids)

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls, or the parser otherwise
                    # wasn't ready to send a token), then get the next token
                    # without streaming a chunk
                    if delta_message is None:
                        continue

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # handle usage stats if requested & if continuous
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if request.stream_options.continuous_usage_stats:
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=num_prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        # any tokens that were generated but previously
                        # matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        if tool_parser:
                            index = len(
                                tool_parser.prev_tool_call_arr) - 1 if len(
                                    tool_parser.prev_tool_call_arr) > 0 else 0
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}))

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)

                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not (tool_parser
                                    and len(tool_parser.prev_tool_call_arr))
                            else "tool_calls",
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if request.stream_options.continuous_usage_stats:
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=num_prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            # once the final token is handled, if stream_options.include_usage
            # is set, send the usage
            if (request.stream_options
                    and request.stream_options.include_usage):
                completion_tokens = previous_num_tokens[i]
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            logger.error(f"error in chat completion stream generator: {e}")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"
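    # A minimal client-side sketch for consuming the stream above with the
    # official `openai` Python client (base URL, port, and model name are
    # assumptions and depend on how the server is launched):
    #
    #   from openai import OpenAI
    #   client = OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")
    #   stream = client.chat.completions.create(
    #       model="<served-model-name>",
    #       messages=[{"role": "user", "content": "Hello!"}],
    #       stream=True,
    #   )
    #   for chunk in stream:
    #       print(chunk.choices[0].delta.content or "", end="")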

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
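        # Non-streaming path: drain the result generator down to the final
        # RequestOutput, then build a single ChatCompletionResponse; tool
        # calls are extracted from the completed text in one pass rather than
        # incrementally as in the streaming generator.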
        model_name = self.base_model_paths[0].name
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            # by default, tools are not used.
            tools_called = False

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            if (not self.enable_auto_tools
                    or not self.tool_parser) and not isinstance(
                        request.tool_choice,
                        ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
                tools_called = True

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:

                tool_parser = self.tool_parser(tokenizer)
                tool_call_info = tool_parser.extract_tool_calls(output.text)
                tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)

                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type and construct it properly later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
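        # Logprob values are clamped at -9999.0 below, presumably so that
        # -inf never leaks into the JSON response.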
        return [
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)

    def _should_stream_with_auto_tool_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the
        tool call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.
        """
        return (request.tools and self.tool_parser and self.enable_auto_tools
                and request.tool_choice in ['auto', None])

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Check to see if we should look for unstreamed tool argument tokens.
        This is only applicable when auto tool parsing is enabled and the
        delta is a tool call with arguments.
        """
        # yapf: disable
        return bool(
            # if there is a delta message that includes tool calls which
            # include a function that has arguments
            output.finish_reason is not None
            and self.enable_auto_tools and self.tool_parser and delta_message
            and delta_message.tool_calls and delta_message.tool_calls[0]
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
        )