test_parallel_tool_calls.py

import json
from typing import Dict, List, Optional

import openai
import pytest

from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
                    MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
                    WEATHER_TOOL)


# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future, e.g. llama 3.1 models are not designed to
# support parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = chat_completion.choices[0]
    stop_reason = chat_completion.choices[0].finish_reason
    non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls

    # make sure 2 tool calls are present
    assert choice.message.role == "assistant"
    assert non_streamed_tool_calls is not None
    assert len(non_streamed_tool_calls) == 2

    for tool_call in non_streamed_tool_calls:
        # make sure the tool includes a function and ID
        assert tool_call.type == "function"
        assert tool_call.function is not None
        assert isinstance(tool_call.id, str)
        assert len(tool_call.id) > 16

        # make sure the weather tool was called correctly
        assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
        assert isinstance(tool_call.function.arguments, str)

        parsed_arguments = json.loads(tool_call.function.arguments)
        assert isinstance(parsed_arguments, Dict)
        assert isinstance(parsed_arguments.get("city"), str)
        assert isinstance(parsed_arguments.get("state"), str)

    assert stop_reason == "tool_calls"

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_tokens=200,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    role_name: Optional[str] = None
    finish_reason_count: int = 0

    tool_call_names: List[str] = []
    tool_call_args: List[str] = []
    tool_call_idx: int = -1
    tool_call_id_count: int = 0
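
    # walk the stream, validating each delta and accumulating the tool call
    # names and arguments so the final result can be compared against the
    # non-streamed completion above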
    async for chunk in stream:

        # if there's a finish reason make sure it's tools
        if chunk.choices[0].finish_reason:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == 'tool_calls'

        # if a role is being streamed make sure it wasn't already set to
        # something else
        if chunk.choices[0].delta.role:
            assert not role_name or role_name == 'assistant'
            role_name = 'assistant'

        # if a tool call is streamed make sure there's exactly one
        # (based on the request parameters)
        streamed_tool_calls = chunk.choices[0].delta.tool_calls

        if streamed_tool_calls and len(streamed_tool_calls) > 0:

            # make sure only one diff is present - correct even for parallel
            # calls, since deltas for different calls arrive in separate chunks
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                tool_call_args.append("")

            # if a tool call ID is streamed, count it; each call's ID should
            # be a sufficiently long string
            if tool_call.id:
                tool_call_id_count += 1
                assert (isinstance(tool_call.id, str)
                        and (len(tool_call.id) > 16))

            # if parts of the function start being streamed
            if tool_call.function:

                # if the function name is defined, set it. it should be
                # streamed IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    tool_call_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)
                    tool_call_args[
                        tool_call.index] += tool_call.function.arguments

    # one finish reason, one role, and exactly one ID per parallel tool call
    assert finish_reason_count == 1
    assert role_name == 'assistant'
    assert tool_call_id_count == 2

    assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
            len(tool_call_args))

    for i in range(2):
        assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
        streamed_args = json.loads(tool_call_args[i])
        non_streamed_args = json.loads(
            non_streamed_tool_calls[i].function.arguments)
        assert streamed_args == non_streamed_args
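

# A minimal sketch (not exercised by the tests) of the accumulation pattern
# the streaming loop above validates. It assumes `deltas` is the flattened
# list of `chunk.choices[0].delta.tool_calls[0]` entries in arrival order,
# with indices appearing in increasing order -- both properties the test
# asserts.
def accumulate_tool_call_deltas(deltas) -> List[Dict[str, str]]:
    calls: List[Dict[str, str]] = []
    for delta in deltas:
        if delta.index == len(calls):
            # first fragment of a new parallel tool call
            calls.append({"name": "", "arguments": ""})
        if delta.function and delta.function.name:
            # the name arrives whole, exactly once per call
            calls[delta.index]["name"] = delta.function.name
        if delta.function and delta.function.arguments:
            # the argument JSON arrives as string fragments to concatenate
            calls[delta.index]["arguments"] += delta.function.arguments
    return calls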


# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    assert choice.message.tool_calls is None \
        or len(choice.message.tool_calls) == 0
    assert choice.message.content is not None
    assert "98" in choice.message.content  # Dallas temp in tool response
    assert "78" in choice.message.content  # Orlando temp in tool response
    stream = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    chunks: List[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    async for chunk in stream:
        delta = chunk.choices[0].delta

        if delta.role:
            assert not role_sent
            assert delta.role == "assistant"
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == choice.finish_reason

        # no tool call deltas should be streamed back for this conversation
        assert not delta.tool_calls or len(delta.tool_calls) == 0

    assert role_sent
    assert finish_reason_count == 1
    assert len(chunks)
    assert "".join(chunks) == choice.message.content