# serving_chat.py

import asyncio
import json
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final,
                    List, Optional)
from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request
from loguru import logger

from aphrodite.common.config import ModelConfig
from aphrodite.common.outputs import CompletionOutput, RequestOutput
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import iterate_with_cancellation, random_uuid
from aphrodite.endpoints.chat_utils import (ConversationMessage,
                                            apply_hf_chat_template,
                                            apply_mistral_chat_template,
                                            load_chat_template,
                                            parse_chat_messages_futures)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionLogProb, ChatCompletionLogProbs,
    ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                       OpenAIServing,
                                                       PromptAdapterPath,
                                                       TextTokensPrompt)
from aphrodite.endpoints.openai.tool_parsers import (Hermes2ProToolParser,
                                                     MistralToolParser,
                                                     ToolParser)
from aphrodite.engine.protocol import AsyncEngineClient
from aphrodite.inputs import TokensPrompt
from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
                                                    MistralTokenizer)

class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 async_engine_client: AsyncEngineClient,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 response_role: str,
                 *,
                 lora_modules: Optional[List[LoRAModulePath]],
                 prompt_adapters: Optional[List[PromptAdapterPath]],
                 request_logger: Optional[RequestLogger],
                 chat_template: Optional[str],
                 return_tokens_as_token_ids: bool = False,
                 enable_auto_tools: bool = False,
                 tool_parser: Optional[str] = None):
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=prompt_adapters,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids)

        self.response_role = response_role
        self.use_tool_use_model_template = False
        self.chat_template = load_chat_template(chat_template)

        # set up tool use
        self.enable_auto_tools: bool = enable_auto_tools
        if self.enable_auto_tools:
            logger.info(
                "\"auto\" tool choice has been enabled; please note that "
                "while the parallel_tool_calls client option is present for "
                "compatibility reasons, it will be ignored.")

        self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
        if self.enable_auto_tools:
            if tool_parser == "mistral":
                self.tool_parser = MistralToolParser
            elif tool_parser == "hermes":
                self.tool_parser = Hermes2ProToolParser
            else:
                raise TypeError("Error: --enable-auto-tool-choice requires "
                                "--tool-call-parser")

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
               ErrorResponse]:
        """Completion API similar to OpenAI's API.
        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error(f"Error with model {error_check_ret}")
            return error_check_ret

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            model_config = self.model_config
            tokenizer = await self.async_engine_client.get_tokenizer(
                lora_request)

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            tool_dicts = None if request.tools is None else [
                tool.model_dump() for tool in request.tools
            ]

            prompt: Union[str, List[int]]
            is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
            if is_mistral_tokenizer:
                prompt = apply_mistral_chat_template(
                    tokenizer,
                    messages=request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
            else:
                prompt = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=request.chat_template or self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                    tools=tool_dicts,
                    documents=request.documents,
                    **(request.chat_template_kwargs or {}),
                )
        except Exception as e:
            logger.error(f"Error in applying chat template from request: {e}")
            return self.create_error_response(str(e))

        try:
            mm_data = await mm_data_future
        except Exception as e:
            logger.error(f"Error in loading multi-modal data: {e}")
            return self.create_error_response(str(e))

        # validation for OpenAI tools
        # tool_choice = "required" is not supported
        if request.tool_choice == "required":
            return self.create_error_response(
                "tool_choice = \"required\" is not supported!")

        if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
                self.enable_auto_tools and self.tool_parser is not None):
            # for hf tokenizers, "auto" tool choice requires
            # --enable-auto-tool-choice and --tool-call-parser
            return self.create_error_response(
                "\"auto\" tool choice requires "
                "--enable-auto-tool-choice and --tool-call-parser to be set")

        request_id = f"chat-{random_uuid()}"
        try:
            guided_decode_logits_processor = (
                await self._guided_decode_logits_processor(request, tokenizer))

            if isinstance(prompt, str):
                prompt_inputs = self._tokenize_prompt_input(
                    request,
                    tokenizer,
                    prompt,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                assert isinstance(prompt, list) and isinstance(
                    prompt[0], int
                ), "Prompt has to be either a string or a list of token ids"
                prompt_inputs = TextTokensPrompt(
                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

            assert prompt_inputs is not None

            sampling_params = request.to_sampling_params(
                tokenizer,
                guided_decode_logits_processor,
                default_max_tokens=self.max_model_len -
                len(prompt_inputs["prompt_token_ids"]))

            self._log_inputs(request_id,
                             prompt_inputs,
                             params=sampling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)

            engine_inputs = TokensPrompt(
                prompt_token_ids=prompt_inputs["prompt_token_ids"])
            if mm_data is not None:
                engine_inputs["multi_modal_data"] = mm_data

            result_generator = self.async_engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))

        if raw_request:
            result_generator = iterate_with_cancellation(
                result_generator, raw_request.is_disconnected)

        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation, tokenizer)

        try:
            return await self.chat_completion_full_generator(
                request, result_generator, request_id, conversation, tokenizer)
        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            return self.create_error_response(str(e))
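
    # Illustrative client-side sketch for this endpoint; it assumes the server
    # registers the handler at the standard /v1/chat/completions route and
    # listens on localhost:2242 (both live outside this file, adjust to your
    # deployment):
    #
    #   from openai import OpenAI
    #   client = OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")
    #   resp = client.chat.completions.create(
    #       model="my-model",
    #       messages=[{"role": "user", "content": "Hello!"}],
    #   )
    #   print(resp.choices[0].message.content)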

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        return request.messages[-1]["role"]

    async def chat_completion_stream_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
    ) -> AsyncGenerator[str, None]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type: Final = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        num_choices = 1 if request.n is None else request.n
        previous_num_tokens = [0] * num_choices
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0

        tool_parser: Optional[ToolParser] = self.tool_parser(
            tokenizer) if self.tool_parser else None

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
        else:
            tool_choice_function_name = None

        # Determine whether tools are in use with "auto" tool choice
        tool_choice_auto = (
            not tool_choice_function_name
            and self._should_stream_with_auto_tool_parsing(request))

        all_previous_token_ids: Optional[List[List[int]]]
        if tool_choice_auto:
            # These are only required in "auto" tool choice case
            previous_texts = [""] * num_choices
            all_previous_token_ids = [[]] * num_choices
        else:
            previous_texts, all_previous_token_ids = None, None

        try:
            async for res in result_generator:
                if res.prompt_token_ids is not None:
                    num_prompt_tokens = len(res.prompt_token_ids)

                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)

                    # NOTE num_choices defaults to 1 so this usually executes
                    # once per request
                    for i in range(num_choices):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(
                                role=role,
                                content="",
                            ),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # if usage should be included
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            # if continuous usage stats are requested, add it
                            if request.stream_options.continuous_usage_stats:
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=0,
                                    total_tokens=num_prompt_tokens)
                                chunk.usage = usage
                            # otherwise don't
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content: str = ""
                        if conversation and "content" in conversation[
                                -1] and conversation[-1].get("role") == role:
                            last_msg_content = conversation[-1]["content"] or ""

                        if last_msg_content:
                            for i in range(num_choices):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        logprobs=None,
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    model=model_name)
                                if (request.stream_options and
                                        request.stream_options.include_usage):
                                    if (request.stream_options.
                                            continuous_usage_stats):
                                        usage = UsageInfo(
                                            prompt_tokens=num_prompt_tokens,
                                            completion_tokens=0,
                                            total_tokens=num_prompt_tokens)
                                        chunk.usage = usage
                                    else:
                                        chunk.usage = None

                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    if request.logprobs and request.top_logprobs is not None:
                        assert output.logprobs is not None, (
                            "Did not output logprobs")
                        logprobs = self._create_chat_logprobs(
                            token_ids=output.token_ids,
                            top_logprobs=output.logprobs,
                            tokenizer=tokenizer,
                            num_output_top_logprobs=request.top_logprobs,
                        )
                    else:
                        logprobs = None

                    delta_text = output.text
                    delta_message: Optional[DeltaMessage]

                    # handle streaming deltas for tools with named tool_choice
                    if tool_choice_function_name:
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(function=DeltaFunctionCall(
                                name=tool_choice_function_name,
                                arguments=delta_text),
                                          index=i)
                        ])

                    # handle streaming deltas for tools with "auto" tool choice
                    elif tool_choice_auto:
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        assert tool_parser is not None
                        # TODO optimize manipulation of these lists
                        previous_text = previous_texts[i]
                        previous_token_ids = all_previous_token_ids[i]
                        current_text = previous_text + delta_text
                        current_token_ids = previous_token_ids + list(
                            output.token_ids)

                        delta_message = (
                            tool_parser.extract_tool_calls_streaming(
                                previous_text=previous_text,
                                current_text=current_text,
                                delta_text=delta_text,
                                previous_token_ids=previous_token_ids,
                                current_token_ids=current_token_ids,
                                delta_token_ids=output.token_ids))

                        # update the previous values for the next iteration
                        previous_texts[i] = current_text
                        all_previous_token_ids[i] = current_token_ids

                    # handle streaming just a content delta
                    else:
                        delta_message = DeltaMessage(content=delta_text)

                    # set the previous values for the next iteration
                    previous_num_tokens[i] += len(output.token_ids)

                    # if the message delta is None (e.g. because it was a
                    # "control token" for tool calls, or the parser otherwise
                    # wasn't ready to send a token), get the next token
                    # without streaming a chunk
                    if delta_message is None:
                        continue

                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)

                        # handle usage stats if requested & if continuous
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if request.stream_options.continuous_usage_stats:
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=num_prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # if the model is finished generating
                    else:
                        # check to make sure we haven't "forgotten" to stream
                        # any tokens that were generated but previously
                        # matched by partial json parsing
                        # only happens if we are NOT using guided decoding
                        if tool_parser:
                            index = len(
                                tool_parser.prev_tool_call_arr) - 1 if len(
                                    tool_parser.prev_tool_call_arr) > 0 else 0
                        else:
                            index = 0

                        if self._should_check_for_unstreamed_tool_arg_tokens(
                                delta_message, output) and tool_parser:
                            # get the expected call based on partial JSON
                            # parsing which "autocompletes" the JSON
                            expected_call = json.dumps(
                                tool_parser.prev_tool_call_arr[index].get(
                                    "arguments", {}))

                            # get what we've streamed so far for arguments
                            # for the current tool
                            actual_call = tool_parser.streamed_args_for_tool[
                                index]

                            # check to see if there's anything left to stream
                            remaining_call = expected_call.replace(
                                actual_call, "", 1)
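                            # e.g. if expected_call is '{"location": "Dallas"}'
                            # and only '{"location": "Dal' has been streamed so
                            # far, remaining_call is 'las"}', emitted below as
                            # one final arguments delta.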

                            # set that as a delta message
                            delta_message = DeltaMessage(tool_calls=[
                                DeltaToolCall(index=index,
                                              function=DeltaFunctionCall(
                                                  arguments=remaining_call).
                                              model_dump(exclude_none=True))
                            ])

                        # Send the finish response for each request.n only once
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=delta_message,
                            logprobs=logprobs,
                            finish_reason=output.finish_reason
                            if not (tool_parser
                                    and len(tool_parser.prev_tool_call_arr))
                            else "tool_calls",
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if (request.stream_options
                                and request.stream_options.include_usage):
                            if request.stream_options.continuous_usage_stats:
                                completion_tokens = len(output.token_ids)
                                usage = UsageInfo(
                                    prompt_tokens=num_prompt_tokens,
                                    completion_tokens=completion_tokens,
                                    total_tokens=num_prompt_tokens +
                                    completion_tokens,
                                )
                                chunk.usage = usage
                            else:
                                chunk.usage = None

                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True

            # once the final token is handled, if stream_options.include_usage
            # is sent, send the usage
            if (request.stream_options
                    and request.stream_options.include_usage):
                # total completion tokens across all choices, not just the
                # last choice handled in the loop above
                completion_tokens = sum(previous_num_tokens)
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

                final_usage_chunk = ChatCompletionStreamResponse(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage)
                final_usage_data = (final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True))
                yield f"data: {final_usage_data}\n\n"

        except ValueError as e:
            # TODO: Use an aphrodite-specific Validation Error
            logger.error(f"error in chat completion stream generator: {e}")
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator[RequestOutput],
        request_id: str,
        conversation: List[ConversationMessage],
        tokenizer: AnyTokenizer,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")

        assert final_res is not None

        choices: List[ChatCompletionResponseChoice] = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            if request.logprobs and request.top_logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_chat_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=request.top_logprobs,
                    tokenizer=tokenizer,
                )
            else:
                logprobs = None

            # by default, tools are not used.
            tools_called = False

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            if (not self.enable_auto_tools
                    or not self.tool_parser) and not isinstance(
                        request.tool_choice,
                        ChatCompletionNamedToolChoiceParam):
                message = ChatMessage(role=role, content=output.text)

            # if the request uses tools and specified a tool choice
            elif request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                message = ChatMessage(
                    role=role,
                    content="",
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
                tools_called = True

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
            elif not request.tool_choice or request.tool_choice == "none":
                message = ChatMessage(role=role, content=output.text)

            # handle when there are tools and tool choice is auto
            elif request.tools and (
                    request.tool_choice == "auto"
                    or request.tool_choice is None) and self.enable_auto_tools \
                    and self.tool_parser:
                tool_parser = self.tool_parser(tokenizer)
                tool_call_info = tool_parser.extract_tool_calls(output.text)
                tools_called = tool_call_info.tools_called
                if tool_call_info.tools_called:
                    message = ChatMessage(role=role,
                                          content=tool_call_info.content,
                                          tool_calls=tool_call_info.tool_calls)
                else:
                    # FOR NOW make it a chat message; we will have to detect
                    # the type to make it later.
                    message = ChatMessage(role=role, content=output.text)

            # undetermined case that is still important to handle
            else:
                logger.error(
                    "Error in chat_completion_full_generator - cannot determine"
                    " if tools should be extracted. Returning a standard chat "
                    "completion.")
                message = ChatMessage(role=role, content=output.text)

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls" if tools_called else
                output.finish_reason if output.finish_reason else "stop",
                stop_reason=output.stop_reason)
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if conversation and "content" in conversation[-1] and conversation[
                    -1].get("role") == role:
                last_msg_content = conversation[-1]["content"] or ""

            for choice in choices:
                full_message = last_msg_content + (choice.message.content
                                                   or "")
                choice.message.content = full_message

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=final_res.prompt_logprobs,
        )

        return response

    def _get_top_logprobs(
            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
            tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
        return [
            ChatCompletionLogProb(token=(token := self._get_decoded_token(
                p[1],
                p[0],
                tokenizer,
                return_as_token_id=self.return_tokens_as_token_ids)),
                                  logprob=max(p[1].logprob, -9999.0),
                                  bytes=list(
                                      token.encode("utf-8", errors="replace")))
            for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
        ]

    def _create_chat_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
        tokenizer: AnyTokenizer,
        num_output_top_logprobs: Optional[int] = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""
        logprobs_content: List[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
                token = tokenizer.decode(token_id)
                if self.return_tokens_as_token_ids:
                    token = f"token_id:{token_id}"
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        bytes=list(token.encode("utf-8", errors="replace")),
                    ))
            else:
                step_token = step_top_logprobs[token_id]
                step_decoded = step_token.decoded_token

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=self._get_decoded_token(
                            step_token,
                            token_id,
                            tokenizer,
                            self.return_tokens_as_token_ids,
                        ),
                        logprob=max(step_token.logprob, -9999.0),
                        bytes=None if step_decoded is None else list(
                            step_decoded.encode("utf-8", errors="replace")),
                        top_logprobs=self._get_top_logprobs(
                            step_top_logprobs,
                            num_output_top_logprobs,
                            tokenizer,
                        ),
                    ))

        return ChatCompletionLogProbs(content=logprobs_content)
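
    # Sketch of the OpenAI-style structure this returns (one entry per sampled
    # token; top_logprobs is populated only when alternatives are requested):
    #
    #   ChatCompletionLogProbs(content=[
    #       ChatCompletionLogProbsContent(
    #           token="Hello", logprob=-0.12, bytes=[72, 101, 108, 108, 111],
    #           top_logprobs=[ChatCompletionLogProb(token="Hello",
    #                                               logprob=-0.12,
    #                                               bytes=[...]), ...]),
    #       ...,
    #   ])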

    def _should_stream_with_auto_tool_parsing(self,
                                              request: ChatCompletionRequest):
        """
        Utility function to check if streamed tokens should go through the tool
        call parser that was configured.

        We only want to do this IF user-provided tools are set, a tool parser
        is configured, "auto" tool choice is enabled, and the request's tool
        choice field indicates that "auto" tool choice should be used.
        """
        return (request.tools and self.tool_parser and self.enable_auto_tools
                and request.tool_choice in ['auto', None])

    def _should_check_for_unstreamed_tool_arg_tokens(
        self,
        delta_message: Optional[DeltaMessage],
        output: CompletionOutput,
    ) -> bool:
        """
        Check to see if we should check for unstreamed tool argument tokens.

        This is only applicable when auto tool parsing is enabled and the
        delta is a tool call with arguments.
        """

        # yapf: disable
        return bool(
            # if there is a delta message that includes tool calls which
            # include a function that has arguments
            output.finish_reason is not None
            and self.enable_auto_tools and self.tool_parser and delta_message
            and delta_message.tool_calls and delta_message.tool_calls[0]
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
        )