# test_tool_calls.py — end-to-end tests for OpenAI-compatible tool/function calling.
  1. import json
  2. from typing import Dict, List, Optional
  3. import openai
  4. import pytest
  5. from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
  6. SEARCH_TOOL, WEATHER_TOOL)
  7. # test: request a chat completion that should return tool calls, so we know they
  8. # are parsable
  9. @pytest.mark.asyncio
  10. async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
  11. models = await client.models.list()
  12. model_name: str = models.data[0].id
  13. chat_completion = await client.chat.completions.create(
  14. messages=MESSAGES_ASKING_FOR_TOOLS,
  15. temperature=0,
  16. max_tokens=100,
  17. model=model_name,
  18. tools=[WEATHER_TOOL, SEARCH_TOOL],
  19. logprobs=False)
  20. choice = chat_completion.choices[0]
  21. stop_reason = chat_completion.choices[0].finish_reason
  22. tool_calls = chat_completion.choices[0].message.tool_calls
  23. # make sure a tool call is present
  24. assert choice.message.role == 'assistant'
  25. assert tool_calls is not None
  26. assert len(tool_calls) == 1
  27. assert tool_calls[0].type == 'function'
  28. assert tool_calls[0].function is not None
  29. assert isinstance(tool_calls[0].id, str)
  30. assert len(tool_calls[0].id) > 16
  31. # make sure the weather tool was called (classic example) with arguments
  32. assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
  33. assert tool_calls[0].function.arguments is not None
  34. assert isinstance(tool_calls[0].function.arguments, str)
  35. # make sure the arguments parse properly
  36. parsed_arguments = json.loads(tool_calls[0].function.arguments)
  37. assert isinstance(parsed_arguments, Dict)
  38. assert isinstance(parsed_arguments.get("city"), str)
  39. assert isinstance(parsed_arguments.get("state"), str)
  40. assert parsed_arguments.get("city") == "Dallas"
  41. assert parsed_arguments.get("state") == "TX"
  42. assert stop_reason == "tool_calls"
  43. function_name: Optional[str] = None
  44. function_args_str: str = ''
  45. tool_call_id: Optional[str] = None
  46. role_name: Optional[str] = None
  47. finish_reason_count: int = 0
  48. # make the same request, streaming
  49. stream = await client.chat.completions.create(
  50. model=model_name,
  51. messages=MESSAGES_ASKING_FOR_TOOLS,
  52. temperature=0,
  53. max_tokens=100,
  54. tools=[WEATHER_TOOL, SEARCH_TOOL],
  55. logprobs=False,
  56. stream=True)
  57. async for chunk in stream:
  58. assert chunk.choices[0].index == 0
  59. if chunk.choices[0].finish_reason:
  60. finish_reason_count += 1
  61. assert chunk.choices[0].finish_reason == 'tool_calls'
  62. # if a role is being streamed make sure it wasn't already set to
  63. # something else
  64. if chunk.choices[0].delta.role:
  65. assert not role_name or role_name == 'assistant'
  66. role_name = 'assistant'
  67. # if a tool call is streamed make sure there's exactly one
  68. # (based on the request parameters
  69. streamed_tool_calls = chunk.choices[0].delta.tool_calls
  70. if streamed_tool_calls and len(streamed_tool_calls) > 0:
  71. assert len(streamed_tool_calls) == 1
  72. tool_call = streamed_tool_calls[0]
  73. # if a tool call ID is streamed, make sure one hasn't been already
  74. if tool_call.id:
  75. assert not tool_call_id
  76. tool_call_id = tool_call.id
  77. # if parts of the function start being streamed
  78. if tool_call.function:
  79. # if the function name is defined, set it. it should be streamed
  80. # IN ENTIRETY, exactly one time.
  81. if tool_call.function.name:
  82. assert function_name is None
  83. assert isinstance(tool_call.function.name, str)
  84. function_name = tool_call.function.name
  85. if tool_call.function.arguments:
  86. assert isinstance(tool_call.function.arguments, str)
  87. function_args_str += tool_call.function.arguments
  88. assert finish_reason_count == 1
  89. assert role_name == 'assistant'
  90. assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
  91. # validate the name and arguments
  92. assert function_name == WEATHER_TOOL["function"]["name"]
  93. assert function_name == tool_calls[0].function.name
  94. assert isinstance(function_args_str, str)
  95. # validate arguments
  96. streamed_args = json.loads(function_args_str)
  97. assert isinstance(streamed_args, Dict)
  98. assert isinstance(streamed_args.get("city"), str)
  99. assert isinstance(streamed_args.get("state"), str)
  100. assert streamed_args.get("city") == "Dallas"
  101. assert streamed_args.get("state") == "TX"
  102. # make sure everything matches non-streaming except for ID
  103. assert function_name == tool_calls[0].function.name
  104. assert choice.message.role == role_name
  105. assert choice.message.tool_calls[0].function.name == function_name
  106. # compare streamed with non-streamed args Dict-wise, not string-wise
  107. # because character-to-character comparison might not work e.g. the tool
  108. # call parser adding extra spaces or something like that. we care about the
  109. # dicts matching not byte-wise match
  110. assert parsed_arguments == streamed_args
  111. # test: providing tools and results back to model to get a non-tool response
  112. # (streaming/not)
  113. @pytest.mark.asyncio
  114. async def test_tool_call_with_results(client: openai.AsyncOpenAI):
  115. models = await client.models.list()
  116. model_name: str = models.data[0].id
  117. chat_completion = await client.chat.completions.create(
  118. messages=MESSAGES_WITH_TOOL_RESPONSE,
  119. temperature=0,
  120. max_tokens=100,
  121. model=model_name,
  122. tools=[WEATHER_TOOL, SEARCH_TOOL],
  123. logprobs=False)
  124. choice = chat_completion.choices[0]
  125. assert choice.finish_reason != "tool_calls" # "stop" or "length"
  126. assert choice.message.role == "assistant"
  127. assert choice.message.tool_calls is None \
  128. or len(choice.message.tool_calls) == 0
  129. assert choice.message.content is not None
  130. assert "98" in choice.message.content # the temperature from the response
  131. stream = await client.chat.completions.create(
  132. messages=MESSAGES_WITH_TOOL_RESPONSE,
  133. temperature=0,
  134. max_tokens=100,
  135. model=model_name,
  136. tools=[WEATHER_TOOL, SEARCH_TOOL],
  137. logprobs=False,
  138. stream=True)
  139. chunks: List[str] = []
  140. finish_reason_count = 0
  141. role_sent: bool = False
  142. async for chunk in stream:
  143. delta = chunk.choices[0].delta
  144. if delta.role:
  145. assert not role_sent
  146. assert delta.role == "assistant"
  147. role_sent = True
  148. if delta.content:
  149. chunks.append(delta.content)
  150. if chunk.choices[0].finish_reason is not None:
  151. finish_reason_count += 1
  152. assert chunk.choices[0].finish_reason == choice.finish_reason
  153. assert not delta.tool_calls or len(delta.tool_calls) == 0
  154. assert role_sent
  155. assert finish_reason_count == 1
  156. assert len(chunks)
  157. assert "".join(chunks) == choice.message.content