tool_use.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
"""
Set up this example by starting an Aphrodite OpenAI-compatible server with tool
call options enabled. For example:

IMPORTANT: for mistral, you must use one of the provided mistral tool call
templates, or your own - the model's default chat template doesn't work for
tool calls with Aphrodite.

See the Aphrodite docs on the OpenAI-compatible server & tool calling for more
details.

aphrodite run mistralai/Mistral-7B-Instruct-v0.3 \
    --chat-template examples/chat_templates/mistral_tool.jinja \
    --enable-auto-tool-choice --tool-call-parser mistral

OR

aphrodite run NousResearch/Hermes-2-Pro-Llama-3-8B \
    --chat-template examples/chat_templates/hermes_tool.jinja \
    --enable-auto-tool-choice --tool-call-parser hermes
"""
  16. import json
  17. from openai import OpenAI
  18. # Modify OpenAI's API key and API base to use Aphrodite's API server.
  19. openai_api_key = "EMPTY"
  20. openai_api_base = "http://localhost:2242/v1"
  21. client = OpenAI(
  22. # defaults to os.environ.get("OPENAI_API_KEY")
  23. api_key=openai_api_key,
  24. base_url=openai_api_base,
  25. )
  26. models = client.models.list()
  27. model = models.data[0].id
  28. tools = [{
  29. "type": "function",
  30. "function": {
  31. "name": "get_current_weather",
  32. "description": "Get the current weather in a given location",
  33. "parameters": {
  34. "type": "object",
  35. "properties": {
  36. "city": {
  37. "type":
  38. "string",
  39. "description":
  40. "The city to find the weather for, e.g. 'San Francisco'"
  41. },
  42. "state": {
  43. "type":
  44. "string",
  45. "description":
  46. "the two-letter abbreviation for the state that the city is"
  47. " in, e.g. 'CA' which would mean 'California'"
  48. },
  49. "unit": {
  50. "type": "string",
  51. "description": "The unit to fetch the temperature in",
  52. "enum": ["celsius", "fahrenheit"]
  53. }
  54. },
  55. "required": ["city", "state", "unit"]
  56. }
  57. }
  58. }]
  59. messages = [{
  60. "role": "user",
  61. "content": "Hi! How are you doing today?"
  62. }, {
  63. "role": "assistant",
  64. "content": "I'm doing well! How can I help you?"
  65. }, {
  66. "role":
  67. "user",
  68. "content":
  69. "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
  70. }]
  71. chat_completion = client.chat.completions.create(messages=messages,
  72. model=model,
  73. tools=tools)
  74. print("Chat completion results:")
  75. print(chat_completion)
  76. print("\n\n")
  77. tool_calls_stream = client.chat.completions.create(messages=messages,
  78. model=model,
  79. tools=tools,
  80. stream=True)
  81. chunks = []
  82. for chunk in tool_calls_stream:
  83. chunks.append(chunk)
  84. if chunk.choices[0].delta.tool_calls:
  85. print(chunk.choices[0].delta.tool_calls[0])
  86. else:
  87. print(chunk.choices[0].delta)
  88. arguments = []
  89. tool_call_idx = -1
  90. for chunk in chunks:
  91. if chunk.choices[0].delta.tool_calls:
  92. tool_call = chunk.choices[0].delta.tool_calls[0]
  93. if tool_call.index != tool_call_idx:
  94. if tool_call_idx >= 0:
  95. print(
  96. f"streamed tool call arguments: {arguments[tool_call_idx]}"
  97. )
  98. tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
  99. arguments.append("")
  100. if tool_call.id:
  101. print(f"streamed tool call id: {tool_call.id} ")
  102. if tool_call.function:
  103. if tool_call.function.name:
  104. print(f"streamed tool call name: {tool_call.function.name}")
  105. if tool_call.function.arguments:
  106. arguments[tool_call_idx] += tool_call.function.arguments
  107. if len(arguments):
  108. print(f"streamed tool call arguments: {arguments[-1]}")
  109. print("\n\n")
  110. messages.append({
  111. "role": "assistant",
  112. "tool_calls": chat_completion.choices[0].message.tool_calls
  113. })
  114. # Now, simulate a tool call
  115. def get_current_weather(city: str, state: str, unit: 'str'):
  116. return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
  117. "partly cloudly, with highs in the 90's.")
  118. available_tools = {"get_current_weather": get_current_weather}
  119. completion_tool_calls = chat_completion.choices[0].message.tool_calls
  120. for call in completion_tool_calls:
  121. tool_to_call = available_tools[call.function.name]
  122. args = json.loads(call.function.arguments)
  123. result = tool_to_call(**args)
  124. print(result)
  125. messages.append({
  126. "role": "tool",
  127. "content": result,
  128. "tool_call_id": call.id,
  129. "name": call.function.name
  130. })
  131. chat_completion_2 = client.chat.completions.create(messages=messages,
  132. model=model,
  133. tools=tools,
  134. stream=False)
  135. print("\n\n")
  136. print(chat_completion_2)