Browse Source

protocol and abstract class

AlpinDale 6 months ago
parent
commit
5630aa378b

+ 112 - 11
aphrodite/endpoints/openai/protocol.py

@@ -4,10 +4,11 @@ import time
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import torch
+from openai.types.chat import ChatCompletionContentPartParam
 from pydantic import (BaseModel, ConfigDict, Field, model_validator,
                       root_validator)
 from transformers import PreTrainedTokenizer
-from typing_extensions import Annotated
+from typing_extensions import Annotated, Required, TypedDict
 
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import (LogitsProcessorFunc,
@@ -18,6 +19,25 @@ from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
 from aphrodite.endpoints.openai.logits_processors import get_logits_processors
 
 
+class CustomChatCompletionMessageParam(TypedDict, total=False):
+    """Enables custom roles in the Chat Completion API."""
+    role: Required[str]
+    """The role of the message's author."""
+
+    content: Union[str, List[ChatCompletionContentPartParam]]
+    """The contents of the message."""
+
+    name: str
+    """An optional name for the participant.
+    Provides the model information to differentiate between participants of the
+    same role.
+    """
+
+    tool_call_id: Optional[str]
+
+    tool_calls: Optional[List[dict]]
+
+
 class OpenAIBaseModel(BaseModel):
     model_config = ConfigDict(extra="ignore")
 
@@ -119,8 +139,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     tools: Optional[List[ChatCompletionToolsParam]] = None
-    tool_choice: Optional[Union[Literal["none"],
+    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                 ChatCompletionNamedToolChoiceParam]] = "none"
+
+    # NOTE this will be ignored by Aphrodite -- the model determines
+    # the behavior
+    parallel_tool_calls: Optional[bool] = False
     user: Optional[str] = None
 
     # doc: begin-chat-completion-sampling-params
@@ -297,6 +321,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     @model_validator(mode="before")
     @classmethod
     def check_guided_decoding_count(cls, data):
+        if isinstance(data, ValueError):
+            raise data
+
         guide_count = sum([
             "guided_json" in data and data["guided_json"] is not None,
             "guided_regex" in data and data["guided_regex"] is not None,
@@ -308,21 +335,61 @@ class ChatCompletionRequest(OpenAIBaseModel):
                 "You can only use one kind of guided decoding "
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         # you can only either use guided decoding or tools, not both
-        if guide_count > 1 and "tool_choice" in data and data[
-                "tool_choice"] != "none":
+        if guide_count > 1 and data.get("tool_choice",
+                                        "none") not in ("none", "auto"):
             raise ValueError(
                 "You can only either use guided decoding or tools, not both.")
         return data
 
     @model_validator(mode="before")
     @classmethod
-    def check_tool_choice(cls, data):
-        if "tool_choice" in data and data["tool_choice"] != "none":
-            if not isinstance(data["tool_choice"], dict):
-                raise ValueError("Currently only named tools are supported.")
+    def check_tool_usage(cls, data):
+
+        # if "tool_choice" is not specified but tools are provided,
+        # default to "auto" tool_choice
+        if "tool_choice" not in data and "tools" in data:
+            data["tool_choice"] = "auto"
+
+        # if "tool_choice" is specified -- validation
+        if "tool_choice" in data:
+
+            # ensure that if "tool choice" is specified, tools are present
             if "tools" not in data or data["tools"] is None:
                 raise ValueError(
                     "When using `tool_choice`, `tools` must be set.")
+
+            # make sure that tool choice is either a named tool
+            # OR that it's set to "auto"
+            if data["tool_choice"] != "auto" and not isinstance(
+                    data["tool_choice"], dict):
+                raise ValueError(
+                    "`tool_choice` must either be a named tool or \"auto\". "
+                    "`tool_choice=\"none\" is not supported.")
+
+            # ensure that if "tool_choice" is specified as an object,
+            # it matches a valid tool
+            if isinstance(data["tool_choice"], dict):
+                valid_tool = False
+                specified_function = data["tool_choice"]["function"]
+                if not specified_function:
+                    raise ValueError(
+                        "Incorrectly formatted `tool_choice`. Should be like "
+                        "`{\"type\": \"function\","
+                        " \"function\": {\"name\": \"my_function\"}}`")
+                specified_function_name = specified_function["name"]
+                if not specified_function_name:
+                    raise ValueError(
+                        "Incorrectly formatted `tool_choice`. Should be like "
+                        "`{\"type\": \"function\", "
+                        "\"function\": {\"name\": \"my_function\"}}`")
+                for tool in data["tools"]:
+                    if tool["function"]["name"] == specified_function_name:
+                        valid_tool = True
+                        break
+                if not valid_tool:
+                    raise ValueError(
+                        "The tool specified in `tool_choice` does not match any"
+                        " of the specified `tools`")
         return data
 
     @model_validator(mode="before")
@@ -616,9 +683,41 @@ class ToolCall(OpenAIBaseModel):
     function: FunctionCall
 
 
+class DeltaFunctionCall(BaseModel):
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+# a tool call delta where everything is optional
+class DeltaToolCall(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
+    type: Literal["function"] = "function"
+    index: int
+    function: Optional[DeltaFunctionCall] = None
+
+
+# the initial delta that gets sent once a new tool call is started;
+class InitialDeltaToolCall(DeltaToolCall):
+    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
+    type: Literal["function"] = "function"
+    index: int
+
+
+class ExtractedToolCallInformation(BaseModel):
+    # indicate if tools were called
+    tools_called: bool
+
+    # extracted tool calls
+    tool_calls: List[ToolCall]
+
+    # content - per OpenAI spec, content AND tool calls can be returned rarely
+    # But some models will do this intentionally
+    content: Optional[str] = None
+
+
 class ChatMessage(OpenAIBaseModel):
     role: str
-    content: str
+    content: Optional[str] = None
     tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
@@ -640,7 +739,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[str] = None
+    # per OpenAI spec this is the default
+    finish_reason: Optional[str] = "stop"
+    # not part of the OpenAI spec but included in Aphrodite for legacy reasons
     stop_reason: Optional[Union[int, str]] = None
 
 
@@ -657,7 +758,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
 class DeltaMessage(OpenAIBaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
-    tool_calls: List[ToolCall] = Field(default_factory=list)
+    tool_calls: List[DeltaToolCall] = Field(default_factory=list)
 
 
 class ChatCompletionResponseStreamChoice(OpenAIBaseModel):

+ 0 - 0
aphrodite/endpoints/openai/tool_parsers/__init__.py


+ 55 - 0
aphrodite/endpoints/openai/tool_parsers/abstract_tool_parser.py

@@ -0,0 +1,55 @@
+from typing import Dict, List, Sequence, Union
+
+from aphrodite.endpoints.openai.protocol import (DeltaMessage,
+                                                 ExtractedToolCallInformation)
+from aphrodite.transformers_utils.tokenizer import AnyTokenizer
+
+
+class ToolParser:
+    """
+    Abstract ToolParser class that should not be used directly. Provided
+    properties and methods should be used in
+    derived classes.
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        self.prev_tool_call_arr: List[Dict] = []
+        # the index of the tool call that is currently being parsed
+        self.current_tool_id: int = -1
+        self.current_tool_name_sent: bool = False
+        self.current_tool_initial_sent: bool = False
+        self.streamed_args_for_tool: List[str] = []
+
+        self.model_tokenizer = tokenizer
+
+    def extract_tool_calls(self,
+                           model_output: str) -> ExtractedToolCallInformation:
+        """
+        Static method that should be implemented for extracting tool calls from
+        a complete model-generated string.
+        Used for non-streaming responses where we have the entire model response
+        available before sending to the client.
+        Static because it's stateless.
+        """
+        raise NotImplementedError(
+            "AbstractToolParser.extract_tool_calls has not been implemented!")
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+        """
+        Instance method that should be implemented for extracting tool calls
+        from an incomplete response; for use when handling tool calls and
+        streaming. Has to be an instance method because  it requires state -
+        the current tokens/diffs, but also the information about what has
+        previously been parsed and extracted (see constructor)
+        """
+        raise NotImplementedError(
+            "AbstractToolParser.extract_tool_calls_streaming has not been "
+            "implemented!")