AlpinDale 6 months ago
parent
commit
561588d97d

+ 83 - 24
aphrodite/endpoints/chat_utils.py

@@ -1,26 +1,30 @@
 import codecs
+import json
 import tempfile
 from dataclasses import dataclass
-from functools import lru_cache
+from functools import lru_cache, partial
 from pathlib import Path
-from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
-                    Union, cast)
+from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Optional,
+                    Tuple, Union, cast)
 
 import requests
 from loguru import logger
 # yapf conflicts with isort for this block
 # yapf: disable
-from openai.types.chat import ChatCompletionContentPartImageParam
+from openai.types.chat import (ChatCompletionAssistantMessageParam,
+                               ChatCompletionContentPartImageParam)
 from openai.types.chat import (
     ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
-from openai.types.chat import ChatCompletionContentPartTextParam
+from openai.types.chat import (ChatCompletionContentPartRefusalParam,
+                               ChatCompletionContentPartTextParam)
 from openai.types.chat import (
     ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
+from openai.types.chat import (ChatCompletionMessageToolCallParam,
+                               ChatCompletionToolMessageParam)
 # yapf: enable
 # pydantic needs the TypedDict from typing_extensions
 from pydantic import ConfigDict
-from transformers import PreTrainedTokenizer
-from typing_extensions import Required, TypedDict
+from typing_extensions import Required, TypeAlias, TypedDict
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.multimodal import MultiModalDataDict
@@ -50,9 +54,10 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
     """The type of the content part."""
 
 
-ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
-                                       ChatCompletionContentPartAudioParam,
-                                       CustomChatCompletionContentPartParam]
+ChatCompletionContentPartParam: TypeAlias = Union[
+    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
+    ChatCompletionContentPartRefusalParam,
+    CustomChatCompletionContentPartParam]
 
 
 class CustomChatCompletionMessageParam(TypedDict, total=False):
@@ -69,15 +74,33 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
     same role.
     """
 
+    tool_call_id: Optional[str]
+    """Tool call that this message is responding to."""
+
+    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
+    """The tool calls generated by the model, such as function calls."""
+
 
 ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                    CustomChatCompletionMessageParam]
 
 
 # TODO: Make fields ReadOnly once mypy supports it
-class ConversationMessage(TypedDict):
-    role: str
-    content: str
+class ConversationMessage(TypedDict, total=False):
+    role: Required[str]
+    """The role of the message's author."""
+
+    content: Optional[str]
+    """The contents of the message"""
+
+    tool_call_id: Optional[str]
+    """Tool call that this message is responding to."""
+
+    name: Optional[str]
+    """The name of the function to call"""
+
+    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
+    """The tool calls generated by the model, such as function calls."""
 
 
 @dataclass(frozen=True)
@@ -120,7 +143,7 @@ def load_chat_template(
 
 
 @lru_cache(maxsize=None)
-def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
+def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
                   modality: Literal["image", "audio"]) -> Optional[str]:
     # TODO: Let user specify how to insert image tokens into prompt
     # (similar to chat template)
@@ -161,7 +184,7 @@ def _parse_chat_message_content_parts(
     role: str,
     parts: Iterable[ChatCompletionContentPartParam],
     model_config: ModelConfig,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: AnyTokenizer,
 ) -> ChatMessageParseResult:
     texts: List[str] = []
     mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -170,7 +193,8 @@ def _parse_chat_message_content_parts(
     for part in parts:
         part_type = part["type"]
         if part_type == "text":
-            text = cast(ChatCompletionContentPartTextParam, part)["text"]
+            text = _TextParser(part)["text"]
             texts.append(text)
         elif part_type == "image_url":
             modality = "image"
@@ -222,28 +246,49 @@ def _parse_chat_message_content_parts(
     return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
 
 
+# No need to validate using Pydantic again
+_TextParser = partial(cast, ChatCompletionContentPartTextParam)
+_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
+_ToolParser = partial(cast, ChatCompletionToolMessageParam)
+
+
 def _parse_chat_message_content(
     message: ChatCompletionMessageParam,
     model_config: ModelConfig,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: AnyTokenizer,
 ) -> ChatMessageParseResult:
     role = message["role"]
     content = message.get("content")
 
     if content is None:
-        return ChatMessageParseResult(messages=[], mm_futures=[])
-    if isinstance(content, str):
-        messages = [ConversationMessage(role=role, content=content)]
-        return ChatMessageParseResult(messages=messages, mm_futures=[])
+        content = []
+    elif isinstance(content, str):
+        content = [
+            ChatCompletionContentPartTextParam(type="text", text=content)
+        ]
+
+    result = _parse_chat_message_content_parts(role, content,
+                                               model_config, tokenizer)
+    for result_msg in result.messages:
+        if role == "assistant":
+            parsed_msg = _AssistantParser(message)
 
-    return _parse_chat_message_content_parts(role, content, model_config,
-                                             tokenizer)
+            if "tool_calls" in parsed_msg:
+                result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
+        elif role == "tool":
+            parsed_msg = _ToolParser(message)
+            if "tool_call_id" in parsed_msg:
+                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
+
+        if "name" in message and isinstance(message["name"], str):
+            result_msg["name"] = message["name"]
+
+    return result
 
 
 def parse_chat_messages(
     messages: List[ChatCompletionMessageParam],
     model_config: ModelConfig,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: AnyTokenizer,
 ) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
     conversation: List[ConversationMessage] = []
     mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -272,6 +317,20 @@ def apply_chat_template(
             "allowed, so you must provide a chat template if the tokenizer "
             "does not define one.")
 
+    # per the Transformers docs & maintainers, tool call arguments in
+    # assistant-role messages with tool_calls need to be dicts, not JSON
+    # strings - this is how tool-use chat templates will expect them going
+    # forward. So, for messages that have tool_calls, parse the string
+    # (which we get in OpenAI format) into a dict
+    for message in conversation:
+        if (message["role"] == "assistant" and "tool_calls" in message
+                and isinstance(message["tool_calls"], list)):
+
+            for i in range(len(message["tool_calls"])):
+                args: str = message["tool_calls"][i]["function"]["arguments"]
+                parsed_args: Dict = json.loads(args)
+                message["tool_calls"][i]["function"]["arguments"] = parsed_args
+
     prompt = tokenizer.apply_chat_template(
         conversation=conversation,
         chat_template=chat_template,
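
For reference, a minimal runnable sketch of the conversion this loop performs on an assistant message (the message contents here are invented for illustration):

import json

# an assistant message in OpenAI format: arguments arrive as a JSON string
message = {
    "role": "assistant",
    "tool_calls": [{
        "id": "call_abc123",
        "type": "function",
        "function": {"name": "get_weather",
                     "arguments": '{"city": "Berlin", "unit": "celsius"}'},
    }],
}

# the loop above rewrites each arguments field into a dict, which is the
# form tool-use chat templates expect
for tool_call in message["tool_calls"]:
    function = tool_call["function"]
    if isinstance(function["arguments"], str):
        function["arguments"] = json.loads(function["arguments"])

assert message["tool_calls"][0]["function"]["arguments"]["unit"] == "celsius"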

+ 5 - 0
aphrodite/endpoints/openai/tool_parsers/__init__.py

@@ -0,0 +1,5 @@
+from .abstract_tool_parser import ToolParser
+from .hermes_tool_parser import Hermes2ProToolParser
+from .mistral_tool_parser import MistralToolParser
+
+__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"]
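
A sketch of how a caller might select one of the exported parsers by name. The mapping and helper below are illustrative assumptions, not part of this commit; only "mistral" is documented as a --tool-call-parser value in the files below, and "hermes" is an assumed counterpart.

from typing import Type

from aphrodite.endpoints.openai.tool_parsers import (Hermes2ProToolParser,
                                                     MistralToolParser,
                                                     ToolParser)

# hypothetical name -> class mapping for the --tool-call-parser CLI flag
_TOOL_PARSERS = {
    "hermes": Hermes2ProToolParser,
    "mistral": MistralToolParser,
}


def get_tool_parser(name: str) -> Type[ToolParser]:
    """Illustrative lookup of a ToolParser subclass by CLI name."""
    try:
        return _TOOL_PARSERS[name]
    except KeyError:
        raise KeyError(f"unknown tool call parser: {name}") from None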

+ 344 - 0
aphrodite/endpoints/openai/tool_parsers/hermes_tool_parser.py

@@ -0,0 +1,344 @@
+import json
+import re
+from typing import Dict, List, Sequence, Union
+
+import partial_json_parser
+from loguru import logger
+from partial_json_parser.core.options import Allow
+
+from aphrodite.endpoints.openai.protocol import (DeltaFunctionCall,
+                                                 DeltaMessage, DeltaToolCall,
+                                                 ExtractedToolCallInformation,
+                                                 FunctionCall,
+                                                 InitialDeltaToolCall,
+                                                 ToolCall)
+from aphrodite.endpoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser)
+from aphrodite.endpoints.openai.tool_parsers.utils import (
+    extract_intermediate_diff)
+from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
+                                                    MistralTokenizer)
+
+
+class Hermes2ProToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        if isinstance(self.model_tokenizer, MistralTokenizer):
+            logger.error(
+                "Detected Mistral tokenizer when using a Hermes model")
+            self.model_tokenizer = self.model_tokenizer.tokenizer
+
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: List[Dict] = []
+        self.current_tool_id: int = -1
+        self.current_tool_initial_sent: bool = False
+        # map of what has been streamed for each tool so far
+        self.streamed_args_for_tool: List[str] = []
+
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+
+        self.tool_call_regex = re.compile(
+            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
+        self.scratch_pad_regex = re.compile(
+            r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor.")
+        self.tool_call_start_token_id = self.model_tokenizer.vocab.get(
+            self.tool_call_start_token)
+        self.tool_call_end_token_id = self.model_tokenizer.vocab.get(
+            self.tool_call_end_token)
+        if (self.tool_call_start_token_id is None
+                or self.tool_call_end_token_id is None):
+            raise RuntimeError(
+                "Hermes 2 Pro Tool parser could not locate tool call start/end "
+                "tokens in the tokenizer!")
+
+    def extract_tool_calls(self,
+                           model_output: str) -> ExtractedToolCallInformation:
+
+        # sanity check; avoid unnecessary processing
+        if self.tool_call_start_token not in model_output:
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
+        else:
+
+            try:
+                # there are two possible captures - between tags, or between
+                # the last tag and end-of-string - so the result of findall
+                # is a list of tuples where one element is the function call
+                # and the other is an empty string
+                function_call_tuples = (
+                    self.tool_call_regex.findall(model_output))
+
+                # load the JSON, and then use it to build the Function and
+                # Tool Call
+                raw_function_calls = [
+                    json.loads(match[0] if match[0] else match[1])
+                    for match in function_call_tuples
+                ]
+                tool_calls = [
+                    ToolCall(
+                        type="function",
+                        function=FunctionCall(
+                            name=function_call["name"],
+                            # function call args are JSON but as a string
+                            arguments=json.dumps(function_call["arguments"])))
+                    for function_call in raw_function_calls
+                ]
+
+                content = model_output[:model_output.
+                                       find(self.tool_call_start_token)]
+                return ExtractedToolCallInformation(
+                    tools_called=True,
+                    tool_calls=tool_calls,
+                    content=content if content else None)
+
+            except Exception as e:
+                logger.error(
+                    f"Error in extracting tool call from response: {e}")
+                return ExtractedToolCallInformation(tools_called=False,
+                                                    tool_calls=[],
+                                                    content=model_output)
+
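
To make the two-alternative regex concrete, a small self-contained example of the extraction path above, run against an assumed model output:

import json
import re

tool_call_regex = re.compile(
    r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)

model_output = ('Let me check that for you. '
                '<tool_call>{"name": "get_weather", '
                '"arguments": {"city": "Berlin"}}</tool_call>')

# each findall result is a (closed_pair, unterminated_tail) tuple; exactly
# one element is non-empty, so pick whichever matched
matches = tool_call_regex.findall(model_output)
raw_calls = [json.loads(m[0] if m[0] else m[1]) for m in matches]

assert raw_calls == [{"name": "get_weather",
                      "arguments": {"city": "Berlin"}}]
# any text before the first <tool_call> tag is returned as content
content = model_output[:model_output.find("<tool_call>")]
assert content == 'Let me check that for you. '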
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+
+        logger.debug(f"delta_text: {delta_text}")
+        logger.debug(f"delta_token_ids: {delta_token_ids}")
+        # check to see if we should be streaming a tool call - is there a
+        # tool call start token in the tokens generated so far?
+        if self.tool_call_start_token_id not in current_token_ids:
+            logger.debug("No tool call tokens found!")
+            return DeltaMessage(content=delta_text)
+
+        try:
+
+            # figure out where we are in the parsing by counting tool call
+            # start & end tags
+            prev_tool_start_count = previous_token_ids.count(
+                self.tool_call_start_token_id)
+            prev_tool_end_count = previous_token_ids.count(
+                self.tool_call_end_token_id)
+            cur_tool_start_count = current_token_ids.count(
+                self.tool_call_start_token_id)
+            cur_tool_end_count = current_token_ids.count(
+                self.tool_call_end_token_id)
+
+            # case: if we're generating text, OR rounding out a tool call
+            if (cur_tool_start_count == cur_tool_end_count
+                    and prev_tool_end_count == cur_tool_end_count):
+                logger.debug("Generating text content! skipping tool parsing.")
+                if delta_text != self.tool_call_end_token:
+                    return DeltaMessage(content=delta_text)
+
+            # case: if tool open & close tag counts don't match, we're doing
+            # imaginary "else" block here
+            # something with tools with this diff.
+            # flags for partial JSON parting. exported constants from
+            # "Allow" are handled via BIT MASK
+            flags = Allow.ALL if self.current_tool_name_sent \
+                else Allow.ALL & ~Allow.STR
+
+            # case -- we're starting a new tool call
+            if (cur_tool_start_count > cur_tool_end_count
+                    and cur_tool_start_count > prev_tool_start_count):
+                if len(delta_token_ids) > 1:
+                    tool_call_portion = current_text.split(
+                        self.tool_call_start_token)[-1]
+                else:
+                    tool_call_portion = None
+                    delta = None
+
+                text_portion = None
+
+                # set cursors and state appropriately
+                self.current_tool_id += 1
+                self.current_tool_name_sent = False
+                self.current_tool_initial_sent = False
+                self.streamed_args_for_tool.append("")
+                logger.debug(f"Starting on a new tool {self.current_tool_id}")
+
+            # case -- we're updating an existing tool call
+            elif (cur_tool_start_count > cur_tool_end_count
+                  and cur_tool_start_count == prev_tool_start_count):
+
+                # get the portion of the text that's the tool call
+                tool_call_portion = current_text.split(
+                    self.tool_call_start_token)[-1]
+                text_portion = None
+
+            # case -- the current tool call is being closed.
+            elif (cur_tool_start_count == cur_tool_end_count
+                  and cur_tool_end_count > prev_tool_end_count):
+                diff = self.prev_tool_call_arr[self.current_tool_id].get(
+                    "arguments")
+                if diff:
+                    diff = json.dumps(diff).replace(
+                        self.streamed_args_for_tool[self.current_tool_id], "")
+                    logger.debug(
+                        "Finishing tool and found diff that had not "
+                        f"been streamed yet: {diff}")
+                    self.streamed_args_for_tool[self.current_tool_id] \
+                        += diff
+                    return DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          arguments=diff).model_dump(
+                                              exclude_none=True))
+                    ])
+
+            # case -- otherwise we're just generating text
+            else:
+                text = delta_text.replace(self.tool_call_start_token, "")
+                text = text.replace(self.tool_call_end_token, "")
+                delta = DeltaMessage(tool_calls=[], content=text)
+                return delta
+
+            try:
+
+                current_tool_call = partial_json_parser.loads(
+                    tool_call_portion,
+                    flags) if tool_call_portion else None
+                logger.debug(f"Parsed tool call {current_tool_call}")
+            except partial_json_parser.core.exceptions.MalformedJSON:
+                logger.debug('not enough tokens to parse into JSON yet')
+                return None
+
+            # case - we haven't sent the initial delta with the tool call ID
+            #   (it will be sent)
+            if not self.current_tool_initial_sent:
+                self.current_tool_initial_sent = True
+                return DeltaMessage(tool_calls=[
+                    InitialDeltaToolCall(
+                        index=self.current_tool_id).model_dump(
+                            exclude_none=True)
+                ])
+
+            # case - we haven't sent the tool name yet. If it's available, send
+            #   it. otherwise, wait until it's available.
+            elif not self.current_tool_name_sent:
+                function_name: Union[str, None] = current_tool_call.get("name")
+                if function_name:
+                    self.current_tool_name_sent = True
+                    return DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          name=function_name).model_dump(
+                                              exclude_none=True))
+                    ])
+                else:
+                    return None
+            # case -- otherwise, send the tool call delta
+
+            # if the tool call portion is None, send the delta as text
+            if tool_call_portion is None:
+                # if there's text but not tool calls, send that -
+                # otherwise None to skip chunk
+                delta = DeltaMessage(content=delta_text) \
+                    if text_portion is not None else None
+                return delta
+
+            # now, the nitty-gritty of tool calls
+            # now we have the portion to parse as tool call.
+
+            logger.debug("Trying to parse current tool call with ID "
+                         f"{self.current_tool_id}")
+
+            # if we're starting a new tool call, push an empty object in as
+            #   a placeholder for the arguments
+            if len(self.prev_tool_call_arr) <= self.current_tool_id:
+                self.prev_tool_call_arr.append({})
+
+            # main logic for tool parsing here - compare prev. partially-parsed
+            #   JSON to the current partially-parsed JSON
+            prev_arguments = (
+                self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
+            cur_arguments = current_tool_call.get("arguments")
+
+            logger.debug(f"diffing old arguments: {prev_arguments}")
+            logger.debug(f"against new ones: {cur_arguments}")
+
+            # case -- no arguments have been created yet. skip sending a delta.
+            if not cur_arguments and not prev_arguments:
+                logger.debug("Skipping text %s - no arguments", delta_text)
+                delta = None
+
+            # case -- prev arguments are defined, but none are now.
+            #   probably impossible, but not a fatal error - just keep going
+            elif not cur_arguments and prev_arguments:
+                logger.error("should be impossible to have arguments reset "
+                             "mid-call. skipping streaming anything.")
+                delta = None
+
+            # case -- we now have the first info about arguments available from
+            #   autocompleting the JSON
+            elif cur_arguments and not prev_arguments:
+
+                cur_arguments_json = json.dumps(cur_arguments)
+                logger.debug("finding %s in %s", delta_text,
+                             cur_arguments_json)
+
+                # get the location where previous args differ from current
+                args_delta_start_loc = cur_arguments_json.index(delta_text) \
+                                       + len(delta_text)
+
+                # use that to find the actual delta
+                arguments_delta = cur_arguments_json[:args_delta_start_loc]
+                logger.debug("First tokens in arguments received: "
+                             f"{arguments_delta}")
+
+                delta = DeltaMessage(tool_calls=[
+                    DeltaToolCall(index=self.current_tool_id,
+                                  function=DeltaFunctionCall(
+                                      arguments=arguments_delta).model_dump(
+                                          exclude_none=True))
+                ])
+                self.streamed_args_for_tool[self.current_tool_id] \
+                    += arguments_delta
+
+            # last case -- we have an update to existing arguments.
+            elif cur_arguments and prev_arguments:
+
+                cur_args_json = json.dumps(cur_arguments)
+                prev_args_json = json.dumps(prev_arguments)
+                logger.debug(f"Searching for diff between\n{cur_args_json}")
+                logger.debug(f"and\n{prev_args_json}")
+                argument_diff = extract_intermediate_diff(
+                    cur_args_json, prev_args_json)
+                logger.debug("got argument diff %s", argument_diff)
+                delta = DeltaMessage(tool_calls=[
+                    DeltaToolCall(index=self.current_tool_id,
+                                  function=DeltaFunctionCall(
+                                      arguments=argument_diff).model_dump(
+                                          exclude_none=True))
+                ])
+                self.streamed_args_for_tool[self.current_tool_id] \
+                    += argument_diff
+
+            # handle saving the state for the current tool into
+            # the "prev" list for use in diffing for the next iteration
+            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
+                self.prev_tool_call_arr[self.current_tool_id] = \
+                    current_tool_call
+            else:
+                self.prev_tool_call_arr.append(current_tool_call)
+
+            return delta
+
+        except Exception as e:
+            logger.error(f"Error trying to handle streaming tool call: {e}")
+            return None  # do not stream a delta. skip this token ID.
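
Pieced together, the streaming path emits deltas in a fixed order per tool call. A schematic trace (chunk boundaries depend on tokenization; the id value is an assumed example):

# schematic sequence of DeltaMessage payloads for one streamed tool call:
#
# 1. InitialDeltaToolCall       {"index": 0, "type": "function", "id": "..."}
# 2. DeltaFunctionCall (name)   {"index": 0,
#                                "function": {"name": "get_weather"}}
# 3. DeltaFunctionCall (args)   {"index": 0,
#                                "function": {"arguments": '{"city": "'}}
# 4. ... further argument fragments as tokens arrive ...
# 5. on </tool_call>, any argument text implied by partial-JSON completion
#    that was never streamed is flushed as a final arguments delta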

+ 291 - 0
aphrodite/endpoints/openai/tool_parsers/mistral_tool_parser.py

@@ -0,0 +1,291 @@
+import json
+import re
+from typing import Dict, List, Sequence, Union
+
+import partial_json_parser
+from loguru import logger
+from partial_json_parser.core.options import Allow
+
+from aphrodite.endpoints.openai.protocol import (DeltaFunctionCall,
+                                                 DeltaMessage, DeltaToolCall,
+                                                 ExtractedToolCallInformation,
+                                                 FunctionCall,
+                                                 InitialDeltaToolCall,
+                                                 ToolCall)
+from aphrodite.endpoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser)
+from aphrodite.endpoints.openai.tool_parsers.utils import (
+    extract_intermediate_diff)
+from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
+                                                    MistralTokenizer)
+
+
+class MistralToolParser(ToolParser):
+    """
+    Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
+    examples/tool_chat_template_mistral.jinja template.
+    Used when --enable-auto-tool-choice and --tool-call-parser mistral are set
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        if isinstance(self.model_tokenizer, MistralTokenizer):
+            self.model_tokenizer = self.model_tokenizer.tokenizer
+        else:
+            logger.info("Non-Mistral tokenizer detected when using a Mistral "
+                        "model...")
+
+        # initialize properties used for state when parsing tool calls in
+        # streaming mode
+        self.prev_tool_call_arr: List[Dict] = []
+        self.current_tool_id: int = -1
+        self.current_tool_name_sent: bool = False
+        self.current_tool_initial_sent: bool = False
+        # map of what has been streamed for each tool so far
+        self.streamed_args_for_tool: List[str] = []
+        self.bot_token = "[TOOL_CALLS]"
+        self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
+        self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
+
+    def extract_tool_calls(self,
+                           model_output: str) -> ExtractedToolCallInformation:
+        """
+        Extract the tool calls from a complete model response. Requires
+        find-and-replacing single quotes with double quotes for JSON parsing;
+        make sure your tool call arguments never include single quotes!
+        """
+
+        # case -- if a tool call token is not present, return a text response
+        if self.bot_token not in model_output:
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+        try:
+
+            # use a regex to find the tool call array, after removing the
+            # BOT token from the model output
+            raw_tool_call = self.tool_call_regex.findall(
+                model_output.replace(self.bot_token, ""))[0]
+
+            # load the JSON, and then use it to build the Function and
+            # Tool Call
+            function_call_arr = json.loads(raw_tool_call)
+            tool_calls: List[ToolCall] = [
+                ToolCall(
+                    type="function",
+                    function=FunctionCall(
+                        name=raw_function_call["name"],
+                        # function call args are JSON but as a string
+                        arguments=json.dumps(raw_function_call["arguments"])))
+                for raw_function_call in function_call_arr
+            ]
+
+            # get any content before the tool call
+            content = model_output.split(self.bot_token)[0]
+            return ExtractedToolCallInformation(
+                tools_called=True,
+                tool_calls=tool_calls,
+                content=content if len(content) > 0 else None)
+
+        except Exception as e:
+            logger.error(f"Error in extracting tool call from response: {e}")
+            # return information to just treat the tool call as regular JSON
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
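
Analogous to the Hermes example, a runnable sketch of what the [TOOL_CALLS] extraction sees, with an assumed model output:

import json
import re

tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
bot_token = "[TOOL_CALLS]"

model_output = ('[TOOL_CALLS] [{"name": "get_weather", '
                '"arguments": {"city": "Berlin"}}]')

# strip the BOT token, then grab the JSON array of calls
raw_tool_call = tool_call_regex.findall(
    model_output.replace(bot_token, ""))[0]
function_call_arr = json.loads(raw_tool_call)

assert function_call_arr[0]["name"] == "get_weather"
assert function_call_arr[0]["arguments"] == {"city": "Berlin"}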
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+
+        # if the tool call token is not in the tokens generated so far, append
+        # output to contents since it's not a tool
+        if self.bot_token_id not in current_token_ids:
+            return DeltaMessage(content=delta_text)
+
+        # if the tool call token ID IS in the tokens generated so far, that
+        # means we're parsing as tool calls now
+
+        # handle if we detected the BOT token which means the start of tool
+        # calling
+        if (self.bot_token_id in delta_token_ids
+                and len(delta_token_ids) == 1):
+            # if it's the only token, return None, so we don't send a chat
+            # completion chunk and don't send a control token to the client
+            return None
+
+        # bit mask flags for partial JSON parsing. If the name hasn't been
+        # sent yet, don't allow sending an incomplete string, since OpenAI
+        # only ever (as far as I have seen) sends the entire tool/function
+        # name at once.
+        flags = Allow.ALL if self.current_tool_name_sent \
+            else Allow.ALL & ~Allow.STR
+        try:
+
+            # everything after the first BOT token is the tool call payload;
+            # take that portion for (partial) JSON parsing
+            parsable_arr = current_text.split(self.bot_token)[1]
+
+            # tool calls are generated in an array, so do partial JSON
+            # parsing on the entire array
+            try:
+                tool_call_arr: List[Dict] = partial_json_parser.loads(
+                    parsable_arr, flags)
+            except partial_json_parser.core.exceptions.MalformedJSON:
+                logger.debug('not enough tokens to parse into JSON yet')
+                return None
+
+            # select as the current tool call the one our state cursor
+            # (current_tool_id) points at
+
+            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
+                if len(tool_call_arr) > 0 else {}
+
+            # case -- if no tokens have been streamed for the tool, e.g.
+            #   only the array brackets, stream nothing
+            if len(tool_call_arr) == 0:
+                return None
+
+            # case: we are starting a new tool in the array
+            #   -> array has > 0 length AND length has moved past cursor
+            elif (len(tool_call_arr) > 0
+                  and len(tool_call_arr) > self.current_tool_id + 1):
+
+                # if we're moving on to a new call, first make sure we
+                # haven't missed anything in the previous one that was
+                # auto-generated due to JSON completions, but wasn't
+                # streamed to the client yet.
+                if self.current_tool_id >= 0:
+                    diff: Union[str, None] = current_tool_call.get("arguments")
+
+                    if diff:
+                        diff = json.dumps(diff).replace(
+                            self.streamed_args_for_tool[self.current_tool_id],
+                            "")
+                        delta = DeltaMessage(tool_calls=[
+                            DeltaToolCall(index=self.current_tool_id,
+                                          function=DeltaFunctionCall(
+                                              arguments=diff).model_dump(
+                                                  exclude_none=True))
+                        ])
+                        self.streamed_args_for_tool[
+                            self.current_tool_id] += diff
+                    else:
+                        delta = None
+                else:
+                    delta = None
+                # re-set stuff pertaining to progress in the current tool
+                self.current_tool_id = len(tool_call_arr) - 1
+                self.current_tool_name_sent = False
+                self.current_tool_initial_sent = False
+                self.streamed_args_for_tool.append("")
+                logger.debug(f"starting on new tool {self.current_tool_id}")
+                return delta
+
+            # case: update an existing tool - this is handled below
+
+            # if the current tool initial data incl. the id, type=function
+            # and idx not sent, send that
+            if not self.current_tool_initial_sent:
+                self.current_tool_initial_sent = True
+                delta = DeltaMessage(tool_calls=[
+                    InitialDeltaToolCall(
+                        index=self.current_tool_id).model_dump(
+                            exclude_none=True)
+                ])
+
+            # if the current tool name hasn't been sent, send if available
+            # - otherwise send nothing
+            elif not self.current_tool_name_sent:
+                function_name = current_tool_call.get("name")
+                if function_name:
+
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          name=function_name).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.current_tool_name_sent = True
+                else:
+                    delta = None
+
+            # now we know we're on the same tool call and we're streaming
+            # arguments
+            else:
+
+                prev_arguments = self.prev_tool_call_arr[
+                    self.current_tool_id].get("arguments")
+                cur_arguments = current_tool_call.get("arguments")
+
+                new_text = delta_text.replace("\'", "\"")
+
+                if not cur_arguments and not prev_arguments:
+
+                    delta = None
+                elif not cur_arguments and prev_arguments:
+                    logger.error(
+                        "INVARIANT - impossible to have arguments reset "
+                        "mid-arguments")
+                    delta = None
+                elif cur_arguments and not prev_arguments:
+                    cur_arguments_json = json.dumps(cur_arguments)
+                    logger.debug("finding %s in %s", new_text,
+                                 cur_arguments_json)
+
+                    arguments_delta = cur_arguments_json[:cur_arguments_json.
+                                                         index(new_text) +
+                                                         len(new_text)]
+                    logger.debug("First tokens in arguments received: %s",
+                                 arguments_delta)
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          arguments=arguments_delta).
+                                      model_dump(exclude_none=True))
+                    ])
+                    self.streamed_args_for_tool[
+                        self.current_tool_id] += arguments_delta
+
+                elif cur_arguments and prev_arguments:
+                    cur_args_json = json.dumps(cur_arguments)
+                    prev_args_json = json.dumps(prev_arguments)
+                    logger.debug("Searching for diff between \n%s\n%s",
+                                 cur_args_json, prev_args_json)
+
+                    argument_diff = extract_intermediate_diff(
+                        cur_args_json, prev_args_json)
+                    logger.debug(f"got arguments diff: {argument_diff}")
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          arguments=argument_diff).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.streamed_args_for_tool[
+                        self.current_tool_id] += argument_diff
+                else:
+                    # try parsing it with regular JSON - if it works we're
+                    # at the end, and we need to send the difference between
+                    # tokens streamed so far and the valid JSON
+                    delta = None
+
+            # save the current state for use in diffing on the next
+            # iteration, then return the delta (None is the base case)
+            self.prev_tool_call_arr = tool_call_arr
+            return delta
+
+        except Exception as e:
+            logger.error(f"Error trying to handle streaming tool call: {e}")
+            logger.debug(
+                "Skipping chunk as a result of tool streaming extraction "
+                "error")
+            return None
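
Both parsers lean on the Allow bit mask to keep half-generated values from leaking out. A small sketch of the behavior relied on here (results follow partial-json-parser's semantics; treat the exact outputs as approximate):

import partial_json_parser
from partial_json_parser.core.options import Allow

fragment = '{"name": "get_wea'

# with Allow.ALL, an unterminated string may be completed as-is, so the
# half-generated name would leak out:
partial_json_parser.loads(fragment, Allow.ALL)
# -> {'name': 'get_wea'}

# masking out STR forbids incomplete strings, so the half-finished member
# is dropped until it is closed - this is why the parsers withhold the
# function name until it is complete:
partial_json_parser.loads(fragment, Allow.ALL & ~Allow.STR)
# -> {}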

+ 80 - 0
aphrodite/endpoints/openai/tool_parsers/utils.py

@@ -0,0 +1,80 @@
+def find_common_prefix(s1: str, s2: str) -> str:
+    """
+    Finds a common prefix that is shared between two strings, if there is one.
+    Order of arguments is NOT important.
+    This function is provided as a UTILITY for extracting information from JSON
+    generated by partial_json_parser, to help in ensuring that the right tokens
+    are returned in streaming, so that close-quotes, close-brackets and
+    close-braces are not returned prematurely.
+    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
+    '{"fruit": "ap'
+    """
+    prefix = ''
+    min_length = min(len(s1), len(s2))
+    for i in range(0, min_length):
+        if s1[i] == s2[i]:
+            prefix += s1[i]
+        else:
+            break
+    return prefix
+
+
+def find_common_suffix(s1: str, s2: str) -> str:
+    """
+    Finds a common suffix shared between two strings, if there is one. Order of
+    arguments is NOT important.
+    Stops when the suffix ends OR it hits an alphanumeric character
+    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
+    """
+    suffix = ''
+    min_length = min(len(s1), len(s2))
+    for i in range(1, min_length + 1):
+        if s1[-i] == s2[-i] and not s1[-i].isalnum():
+            suffix = s1[-i] + suffix
+        else:
+            break
+    return suffix
+
+
+def extract_intermediate_diff(curr: str, old: str) -> str:
+    """
+    Given two strings, extract the difference in the middle between two strings
+    that are known to have a common prefix and/or suffix.
+    This function is provided as a UTILITY for extracting information from JSON
+    generated by partial_json_parser, to help in ensuring that the right tokens
+    are returned in streaming, so that close-quotes, close-brackets and
+    close-braces are not returned prematurely. The order of arguments IS
+    important - the new version of the partially-parsed JSON must be the first
+    argument, and the second argument must be from the previous generation.
+    What it returns is the text that should be streamed to the client.
+    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
+        -> 'ple'
+    """
+    suffix = find_common_suffix(curr, old)
+
+    old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
+    prefix = find_common_prefix(curr, old)
+    diff = curr
+    if len(suffix):
+        diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]
+
+    if len(prefix):
+        # replace the prefix only once in case it's mirrored
+        diff = diff.replace(prefix, '', 1)
+
+    return diff
+
+
+def find_all_indices(string, substring):
+    """
+    Find all (starting) indices of a substring in a given string. Useful for
+    tool call extraction
+    """
+    indices = []
+    index = -1
+    while True:
+        index = string.find(substring, index + 1)
+        if index == -1:
+            break
+        indices.append(index)
+    return indices
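
Taken together, a short runnable sketch of how these helpers produce a streamable argument fragment from two consecutive partial-JSON snapshots:

from aphrodite.endpoints.openai.tool_parsers.utils import (
    extract_intermediate_diff, find_common_prefix, find_common_suffix)

old = '{"city": "Ber"}'
curr = '{"city": "Berlin"}'

# the shared suffix '"}' is stripped from `old` first, so close-quote and
# close-brace are not mistaken for already-streamed argument text
assert find_common_suffix(curr, old) == '"}'

# what remains shares the prefix '{"city": "Ber' with the new snapshot
assert find_common_prefix(curr, '{"city": "Ber') == '{"city": "Ber'

# the streamable fragment is exactly the newly generated characters
assert extract_intermediate_diff(curr, old) == 'lin'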

+ 3 - 2
aphrodite/modeling/guided_decoding/__init__.py

@@ -55,8 +55,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest,
     if type(request) is CompletionRequest:
         return request
 
-    # user has chosen to not use any tool
-    if request.tool_choice == "none":
+    # user has chosen to not use any tool,
+    # OR is allowing the model to choose a tool.
+    if request.tool_choice == "none" or request.tool_choice == "auto":
         return request
 
     # user has chosen to use a named tool

+ 23 - 8
aphrodite/modeling/guided_decoding/outlines_decoding.py

@@ -8,8 +8,9 @@ from typing import Tuple, Union
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
 
-from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
-                                                 CompletionRequest)
+from aphrodite.endpoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
+    CompletionRequest)
 from aphrodite.modeling.guided_decoding.guided_fields import (
     GuidedDecodingRequest)
 from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
@@ -101,15 +102,29 @@ def _get_guide_and_mode(
     request: Union[CompletionRequest, ChatCompletionRequest,
                    GuidedDecodingRequest]
 ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
-    if request.guided_json:
-        json = request.guided_json
-        if isinstance(json, dict):
+    # if the request is a chat completion request AND the tool choice is a
+    # named tool choice, do guided decoding using that tool's parameters
+    # as the JSON schema
+    if isinstance(request, ChatCompletionRequest) and isinstance(
+            request.tool_choice, ChatCompletionNamedToolChoiceParam):
+        # Guided generation for tools/functions parameters
+        if request.tool_choice.type == "function":
+            for tool in request.tools:
+                if (tool.type == "function" and tool.function.name
+                        == request.tool_choice.function.name):
+                    json = json_dumps(tool.function.parameters, sort_keys=True)
+                    return json, GuidedDecodingMode.JSON
+        return None, None
+    elif request.guided_json:
+        if isinstance(request.guided_json, dict):
             # turn dict into hashable string
-            json = json_dumps(json)
-        elif isinstance(json, BaseModel):
+            json = json_dumps(request.guided_json)
+        elif isinstance(request.guided_json, BaseModel):
             # use pydantic signature so that different model classes
             # with the same fields will get hashed the same
-            json = str(json.__signature__)
+            json = str(request.guided_json.__signature__)
+        else:
+            json = request.guided_json
         return json, GuidedDecodingMode.JSON
     elif request.guided_regex:
         return request.guided_regex, GuidedDecodingMode.REGEX
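
In effect, a named tool choice is translated into a guided-JSON request over that tool's parameter schema. A minimal sketch of the shape involved (plain dicts stand in for the pydantic request models; the schema is invented for illustration):

import json

# tool definition as it would appear in a ChatCompletionRequest
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
tool_choice = {"type": "function", "function": {"name": "get_weather"}}

# mirrors the branch above: find the named tool, then hand its parameter
# schema (as a canonical, hashable string) to the guided-JSON decoder
guide = None
for tool in tools:
    if (tool["type"] == "function"
            and tool["function"]["name"] == tool_choice["function"]["name"]):
        guide = json.dumps(tool["function"]["parameters"], sort_keys=True)
        break

assert guide is not None  # the model is now constrained to this schema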

+ 1 - 0
requirements-common.txt

@@ -17,6 +17,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0
 lm-format-enforcer == 0.10.6
 outlines >= 0.0.43, < 0.1
+partial-json-parser # used for parsing partial JSON outputs
 typing_extensions >= 4.10
 filelock >= 3.10.4
 pyzmq