import json
from typing import Dict, List, Optional

import openai
import pytest

from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
                    SEARCH_TOOL, WEATHER_TOOL)
# test: request a chat completion that should return tool calls, so we know
# they are parsable
@pytest.mark.asyncio
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
    """Verify that a prompt designed to trigger a tool call produces exactly
    one parsable weather tool call, both non-streaming and streaming, and
    that the streamed tool call matches the non-streamed one (ID excepted).
    """
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = chat_completion.choices[0]
    stop_reason = chat_completion.choices[0].finish_reason
    tool_calls = chat_completion.choices[0].message.tool_calls

    # make sure a tool call is present
    assert choice.message.role == 'assistant'
    assert tool_calls is not None
    assert len(tool_calls) == 1
    assert tool_calls[0].type == 'function'
    assert tool_calls[0].function is not None
    assert isinstance(tool_calls[0].id, str)
    assert len(tool_calls[0].id) > 16

    # make sure the weather tool was called (classic example) with arguments
    assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
    assert tool_calls[0].function.arguments is not None
    assert isinstance(tool_calls[0].function.arguments, str)

    # make sure the arguments parse properly
    parsed_arguments = json.loads(tool_calls[0].function.arguments)
    # runtime checks use the builtin `dict`, not the `typing.Dict` alias —
    # typing aliases are for annotations, not isinstance()
    assert isinstance(parsed_arguments, dict)
    assert isinstance(parsed_arguments.get("city"), str)
    assert isinstance(parsed_arguments.get("state"), str)
    assert parsed_arguments.get("city") == "Dallas"
    assert parsed_arguments.get("state") == "TX"

    assert stop_reason == "tool_calls"

    # accumulators for the pieces of the tool call as they stream in
    function_name: Optional[str] = None
    function_args_str: str = ''
    tool_call_id: Optional[str] = None
    role_name: Optional[str] = None
    finish_reason_count: int = 0

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_tokens=100,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    async for chunk in stream:
        assert chunk.choices[0].index == 0

        if chunk.choices[0].finish_reason:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == 'tool_calls'

        # if a role is being streamed make sure it wasn't already set to
        # something else
        if chunk.choices[0].delta.role:
            assert not role_name or role_name == 'assistant'
            role_name = 'assistant'

        # if a tool call is streamed make sure there's exactly one
        # (based on the request parameters)
        streamed_tool_calls = chunk.choices[0].delta.tool_calls
        if streamed_tool_calls and len(streamed_tool_calls) > 0:
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a tool call ID is streamed, make sure one hasn't been already
            if tool_call.id:
                assert not tool_call_id
                tool_call_id = tool_call.id

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be
                # streamed IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert function_name is None
                    assert isinstance(tool_call.function.name, str)
                    function_name = tool_call.function.name
                if tool_call.function.arguments:
                    assert isinstance(tool_call.function.arguments, str)
                    function_args_str += tool_call.function.arguments

    assert finish_reason_count == 1
    assert role_name == 'assistant'
    assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)

    # validate the name and arguments
    assert function_name == WEATHER_TOOL["function"]["name"]
    assert isinstance(function_args_str, str)

    # validate arguments
    streamed_args = json.loads(function_args_str)
    assert isinstance(streamed_args, dict)
    assert isinstance(streamed_args.get("city"), str)
    assert isinstance(streamed_args.get("state"), str)
    assert streamed_args.get("city") == "Dallas"
    assert streamed_args.get("state") == "TX"

    # make sure everything matches non-streaming except for ID
    assert function_name == tool_calls[0].function.name
    assert choice.message.role == role_name
    assert choice.message.tool_calls[0].function.name == function_name

    # compare streamed with non-streamed args Dict-wise, not string-wise
    # because character-to-character comparison might not work e.g. the tool
    # call parser adding extra spaces or something like that. we care about
    # the dicts matching not byte-wise match
    assert parsed_arguments == streamed_args
# test: providing tools and results back to model to get a non-tool response
# (streaming/not)
@pytest.mark.asyncio
async def test_tool_call_with_results(client: openai.AsyncOpenAI):
    """After feeding a tool result back, the model should answer with a plain
    assistant message (no further tool calls), and the streamed reply must
    reassemble to exactly the non-streamed content.
    """
    served_models = await client.models.list()
    served_model: str = served_models.data[0].id

    completion = await client.chat.completions.create(
        messages=MESSAGES_WITH_TOOL_RESPONSE,
        temperature=0,
        max_tokens=100,
        model=served_model,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = completion.choices[0]
    # the model must answer normally now rather than call another tool
    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    assert choice.message.tool_calls is None \
        or len(choice.message.tool_calls) == 0
    assert choice.message.content is not None
    assert "98" in choice.message.content  # the temperature from the response

    # repeat the identical request, streamed this time
    stream = await client.chat.completions.create(
        messages=MESSAGES_WITH_TOOL_RESPONSE,
        temperature=0,
        max_tokens=100,
        model=served_model,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    content_pieces: List[str] = []
    num_finish_chunks = 0
    saw_role: bool = False

    async for chunk in stream:
        delta = chunk.choices[0].delta

        if delta.role:
            # the role may only be announced once, and must be "assistant"
            assert not saw_role
            assert delta.role == "assistant"
            saw_role = True

        if delta.content:
            content_pieces.append(delta.content)

        if chunk.choices[0].finish_reason is not None:
            num_finish_chunks += 1
            assert chunk.choices[0].finish_reason == choice.finish_reason

        # no tool calls should appear anywhere in this stream
        assert not delta.tool_calls or len(delta.tool_calls) == 0

    assert saw_role
    assert num_finish_chunks == 1
    assert len(content_pieces)
    assert "".join(content_pieces) == choice.message.content