Parcourir la source

feat: OpenAI `tools` support named functions

AlpinDale il y a 7 mois
Parent
commit
141c602c39

+ 53 - 2
aphrodite/endpoints/openai/protocol.py

@@ -63,6 +63,26 @@ class ResponseFormat(BaseModel):
     type: str = Literal["text", "json_object"]
 
 
class FunctionDefinition(BaseModel):
    """Schema of one callable function, following the OpenAI `tools` API.

    `parameters` is a JSON-Schema dict describing the function's arguments;
    it is later used as the `guided_json` target when the tool is forced
    (see `_adapt_request_for_tool_use`).
    """
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None
+
+
class ChatCompletionToolsParam(BaseModel):
    """One entry of a request's `tools` list; only "function" tools exist."""
    type: Literal["function"] = "function"
    function: FunctionDefinition
+
+
class ChatCompletionNamedFunction(BaseModel):
    """Reference to a function by name, used inside `tool_choice`."""
    name: str
+
+
class ChatCompletionNamedToolChoiceParam(BaseModel):
    """A `tool_choice` value that forces the model to call one named function."""
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
     # support list type in messages.content
@@ -88,6 +108,9 @@ class ChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = 0.0
     repetition_penalty: Optional[float] = 1.0
     logit_bias: Optional[Dict[str, float]] = None
+    tools: Optional[List[ChatCompletionToolsParam]] = None
+    tool_choice: Optional[Union[Literal["none"],
+                                ChatCompletionNamedToolChoiceParam]] = "none"
     user: Optional[str] = None
     best_of: Optional[int] = None
     top_k: Optional[int] = -1
@@ -194,6 +217,22 @@ class ChatCompletionRequest(BaseModel):
             raise ValueError(
                 "You can only use one kind of guided decoding "
                 "('guided_json', 'guided_regex' or 'guided_choice').")
+        # you can only either use guided decoding or tools, not both
+        if guide_count > 1 and "tool_choice" in data and data[
+                "tool_choice"] != "none":
+            raise ValueError(
+                "You can only either use guided decoding or tools, not both.")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_tool_choice(cls, data):
+        if "tool_choice" in data and data["tool_choice"] != "none":
+            if not isinstance(data["tool_choice"], dict):
+                raise ValueError("Currently only named tools are supported.")
+            if "tools" not in data or data["tools"] is None:
+                raise ValueError(
+                    "When using `tool_choice`, `tools` must be set.")
         return data
 
 
@@ -407,6 +446,17 @@ class EmbeddingResponse(BaseModel):
     usage: UsageInfo
 
 
class FunctionCall(BaseModel):
    """A function invocation emitted in a response.

    `arguments` is the model's raw output text — when a named tool is
    forced, guided decoding constrains it to the tool's JSON schema.
    """
    name: str
    arguments: str
+
+
class ToolCall(BaseModel):
    """A tool invocation attached to a chat message or streaming delta."""
    # Fresh random id per instance, OpenAI-style "chatcmpl-tool-..." prefix.
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall
+
+
class ChatMessage(BaseModel):
    # NOTE(review): the named-tool path in serving_chat constructs
    # ChatMessage(role=..., content="", tool_calls=[...]), but no
    # `tool_calls` field is declared here, so pydantic may silently drop
    # it — confirm and add `tool_calls: List[ToolCall]` (cf. DeltaMessage).
    role: str
    content: str
@@ -422,7 +472,7 @@ class ChatCompletionResponseChoice(BaseModel):
 
class ChatCompletionResponse(BaseModel):
    """Non-streaming chat completion response envelope."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    # Literal (rather than plain str) pins the discriminator value.
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
@@ -432,6 +482,7 @@ class ChatCompletionResponse(BaseModel):
class DeltaMessage(BaseModel):
    """Incremental message fragment for streaming chat responses."""
    role: Optional[str] = None
    content: Optional[str] = None
    # default_factory avoids a shared mutable default across instances.
    tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -444,7 +495,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
 
class ChatCompletionStreamResponse(BaseModel):
    """One SSE chunk of a streaming chat completion."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    # Literal (rather than plain str) pins the discriminator value.
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]

+ 33 - 15
aphrodite/endpoints/openai/serving_chat.py

@@ -8,20 +8,15 @@ from loguru import logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.utils import random_uuid
 from aphrodite.endpoints.openai.protocol import (
-    ChatCompletionRequest,
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
-    ErrorResponse,
-    UsageInfo,
-)
+    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
+    ChatCompletionResponse, ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
+    ChatMessage, DeltaMessage, ErrorResponse, FunctionCall, ToolCall,
+    UsageInfo)
 from aphrodite.endpoints.openai.serving_engine import LoRA, OpenAIServing
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
-from aphrodite.modeling.guided_decoding import (
-    get_guided_decoding_logits_processor)
+from aphrodite.modeling.guided_decoding import \
+    get_guided_decoding_logits_processor
 
 
 class OpenAIServingChat(OpenAIServing):
@@ -212,11 +207,21 @@ class OpenAIServingChat(OpenAIServing):
                     delta_text = output.text[len(previous_texts[i]):]
                     previous_texts[i] = output.text
                     previous_num_tokens[i] = len(output.token_ids)
                    # Named-tool mode: stream the generated text as the JSON
                    # `arguments` of the forced tool call, not as content.
                    if request.tool_choice and type(
                            request.tool_choice
                    ) is ChatCompletionNamedToolChoiceParam:
                        delta_message = DeltaMessage(tool_calls=[
                            # NOTE(review): every delta builds a fresh
                            # ToolCall (new random id) carrying only this
                            # chunk of the arguments — confirm clients
                            # accumulate tool-call deltas by index, not id.
                            ToolCall(function=FunctionCall(
                                name=request.tool_choice.function.name,
                                arguments=delta_text))
                        ])
                    else:
                        # Plain chat completion: stream text content.
                        delta_message = DeltaMessage(content=delta_text)
                     if output.finish_reason is None:
                         # Send token-by-token response for each request.n
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
-                            delta=DeltaMessage(content=delta_text),
+                            delta=delta_message,
                             logprobs=logprobs,
                             finish_reason=None)
                         chunk = ChatCompletionStreamResponse(
@@ -238,7 +243,7 @@ class OpenAIServingChat(OpenAIServing):
                         )
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
-                            delta=DeltaMessage(content=delta_text),
+                            delta=delta_message,
                             logprobs=logprobs,
                             finish_reason=output.finish_reason,
                             stop_reason=output.stop_reason)
@@ -294,9 +299,22 @@ class OpenAIServingChat(OpenAIServing):
             else:
                 logprobs = None
 
            # Named-tool mode: the whole generation is the forced tool's
            # JSON arguments; wrap it in a tool_call with empty content.
            if request.tool_choice and type(
                    request.tool_choice) is ChatCompletionNamedToolChoiceParam:
                message = ChatMessage(
                    role=role,
                    content="",
                    # NOTE(review): ChatMessage as shown in protocol.py
                    # declares no `tool_calls` field, so pydantic may drop
                    # this silently — verify the field exists.
                    tool_calls=[
                        ToolCall(function=FunctionCall(
                            name=request.tool_choice.function.name,
                            arguments=output.text))
                    ])
            elif not request.tool_choice or request.tool_choice == "none":
                # Plain completion (tool_choice absent, None, or "none").
                # NOTE(review): if tool_choice ever held any other value,
                # `message` would be unbound below — confirm the request
                # validators make that impossible.
                message = ChatMessage(role=role, content=output.text)
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
-                message=ChatMessage(role=role, content=output.text),
+                message=message,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
                 stop_reason=output.stop_reason,

+ 31 - 6
aphrodite/modeling/guided_decoding/__init__.py

@@ -1,18 +1,20 @@
 from typing import Optional, Union
 
-from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
-                                                 CompletionRequest)
-from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (
-    get_lm_format_enforcer_guided_decoding_logits_processor)
-from aphrodite.modeling.guided_decoding.outlines_decoding import (
-    get_outlines_guided_decoding_logits_processor)
 from aphrodite.common.sampling_params import LogitsProcessorFunc
+from aphrodite.endpoints.openai.protocol import (
+    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
+    CompletionRequest)
+from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import \
+    get_lm_format_enforcer_guided_decoding_logits_processor
+from aphrodite.modeling.guided_decoding.outlines_decoding import \
+    get_outlines_guided_decoding_logits_processor
 
 
 async def get_guided_decoding_logits_processor(
         guided_decoding_backend: str, request: Union[CompletionRequest,
                                                      ChatCompletionRequest],
         tokenizer) -> Optional[LogitsProcessorFunc]:
+    request = _adapt_request_for_tool_use(request)
     if guided_decoding_backend == 'outlines':
         return await get_outlines_guided_decoding_logits_processor(
             request, tokenizer)
@@ -23,3 +25,26 @@ async def get_guided_decoding_logits_processor(
     raise ValueError(
         f"Unknown guided decoding backend '{guided_decoding_backend}'. "
         "Must be one of 'outlines, 'lm-format-enforcer'")
+
+
def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                               ChatCompletionRequest]):
    """Translate an OpenAI named tool choice into a guided_json constraint.

    If the request forces a named tool, the tool is looked up in
    `request.tools` and `request.guided_json` is set to that tool's
    parameter JSON-Schema, so generation is constrained to valid arguments.
    Mutates `request` in place and returns it.

    Raises:
        ValueError: if the named tool is not present in `tools`.
    """
    # The legacy completion API does not support tool use.
    # (isinstance replaces the non-idiomatic `type(x) is T` checks; the
    # request classes are siblings, so dispatch is unchanged.)
    if isinstance(request, CompletionRequest):
        return request

    # The user has chosen not to use any tool.
    if request.tool_choice == "none":
        return request

    # The user has chosen a named tool.
    if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
        tool_name = request.tool_choice.function.name
        tools = {tool.function.name: tool.function for tool in request.tools}
        if tool_name not in tools:
            raise ValueError(
                f"Tool '{tool_name}' has not been passed in `tools`.")
        request.guided_json = tools[tool_name].parameters

    return request