mistral_tool_parser.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import json
  2. import re
  3. from typing import Dict, List, Sequence, Union
  4. import partial_json_parser
  5. from loguru import logger
  6. from partial_json_parser.core.options import Allow
  7. from aphrodite.common.utils import random_uuid
  8. from aphrodite.endpoints.openai.protocol import (DeltaFunctionCall,
  9. DeltaMessage, DeltaToolCall,
  10. ExtractedToolCallInformation,
  11. FunctionCall, ToolCall)
  12. from aphrodite.endpoints.openai.tool_parsers.abstract_tool_parser import (
  13. ToolParser)
  14. from aphrodite.endpoints.openai.tool_parsers.utils import (
  15. extract_intermediate_diff)
  16. from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
  17. MistralTokenizer)
  18. class MistralToolParser(ToolParser):
  19. """
  20. Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
  21. examples/tool_chat_template_mistral.jinja template.
  22. Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
  23. """
  24. def __init__(self, tokenizer: AnyTokenizer):
  25. super().__init__(tokenizer)
  26. if isinstance(self.model_tokenizer, MistralTokenizer):
  27. self.model_tokenizer = self.model_tokenizer.tokenizer
  28. else:
  29. logger.info(
  30. "Non-Mistral tokenizer detected when using a Mistral "
  31. "model..."
  32. )
  33. # initialize properties used for state when parsing tool calls in
  34. # streaming mode
  35. self.prev_tool_call_arr: List[Dict] = []
  36. self.current_tool_id: int = -1
  37. self.current_tool_name_sent: bool = False
  38. self.streamed_args_for_tool: List[
  39. str
  40. ] = [] # map what has been streamed for each tool so far to a list
  41. self.bot_token = "[TOOL_CALLS]"
  42. self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
  43. self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
  44. def extract_tool_calls(
  45. self, model_output: str
  46. ) -> ExtractedToolCallInformation:
  47. """
  48. Extract the tool calls from a complete model response. Requires
  49. find-and-replacing single quotes with double quotes for JSON parsing,
  50. make sure your tool call arguments don't ever include quotes!
  51. """
  52. # case -- if a tool call token is not present, return a text response
  53. if self.bot_token not in model_output:
  54. return ExtractedToolCallInformation(
  55. tools_called=False, tool_calls=[], content=model_output
  56. )
  57. try:
  58. # use a regex to find the tool call. remove the BOT token
  59. # and make sure to replace single quotes with double quotes
  60. raw_tool_call = self.tool_call_regex.findall(
  61. model_output.replace(self.bot_token, "")
  62. )[0]
  63. # load the JSON, and then use it to build the Function and
  64. # Tool Call
  65. function_call_arr = json.loads(raw_tool_call)
  66. tool_calls: List[ToolCall] = [
  67. ToolCall(
  68. type="function",
  69. function=FunctionCall(
  70. name=raw_function_call["name"],
  71. # function call args are JSON but as a string
  72. arguments=json.dumps(raw_function_call["arguments"]),
  73. ),
  74. )
  75. for raw_function_call in function_call_arr
  76. ]
  77. # get any content before the tool call
  78. content = model_output.split(self.bot_token)[0]
  79. return ExtractedToolCallInformation(
  80. tools_called=True,
  81. tool_calls=tool_calls,
  82. content=content if len(content) > 0 else None,
  83. )
  84. except Exception as e:
  85. logger.error(f"Error in extracting tool call from response: {e}")
  86. # return information to just treat the tool call as regular JSON
  87. return ExtractedToolCallInformation(
  88. tools_called=False, tool_calls=[], content=model_output
  89. )
  90. def extract_tool_calls_streaming(
  91. self,
  92. previous_text: str,
  93. current_text: str,
  94. delta_text: str,
  95. previous_token_ids: Sequence[int],
  96. current_token_ids: Sequence[int],
  97. delta_token_ids: Sequence[int],
  98. ) -> Union[DeltaMessage, None]:
  99. # if the tool call token is not in the tokens generated so far, append
  100. # output to contents since it's not a tool
  101. if self.bot_token not in current_text:
  102. return DeltaMessage(content=delta_text)
  103. # if the tool call token ID IS in the tokens generated so far, that
  104. # means we're parsing as tool calls now
  105. # handle if we detected the BOT token which means the start of tool
  106. # calling
  107. if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1:
  108. # if it's the only token, return None, so we don't send a chat
  109. # completion any don't send a control token
  110. return None
  111. # bit mask flags for partial JSON parsing. If the name hasn't been
  112. # sent yet, don't allow sending
  113. # an incomplete string since OpenAI only ever (as far as I have
  114. # seen) allows sending the entire tool/ function name at once.
  115. flags = (
  116. Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
  117. )
  118. try:
  119. # replace BOT token with empty string, and convert single quotes
  120. # to double to allow parsing as JSON since mistral uses single
  121. # quotes instead of double for tool calls
  122. parsable_arr = current_text.split(self.bot_token)[-1]
  123. # tool calls are generated in an array, so do partial JSON
  124. # parsing on the entire array
  125. try:
  126. tool_call_arr: List[Dict] = partial_json_parser.loads(
  127. parsable_arr, flags
  128. )
  129. except partial_json_parser.core.exceptions.MalformedJSON:
  130. logger.debug("not enough tokens to parse into JSON yet")
  131. return None
  132. # select as the current tool call the one we're on the state at
  133. current_tool_call: Dict = (
  134. tool_call_arr[self.current_tool_id]
  135. if len(tool_call_arr) > 0
  136. else {}
  137. )
  138. # case -- if no tokens have been streamed for the tool, e.g.
  139. # only the array brackets, stream nothing
  140. if len(tool_call_arr) == 0:
  141. return None
  142. # case: we are starting a new tool in the array
  143. # -> array has > 0 length AND length has moved past cursor
  144. elif (
  145. len(tool_call_arr) > 0
  146. and len(tool_call_arr) > self.current_tool_id + 1
  147. ):
  148. # if we're moving on to a new call, first make sure we
  149. # haven't missed anything in the previous one that was
  150. # auto-generated due to JSON completions, but wasn't
  151. # streamed to the client yet.
  152. if self.current_tool_id >= 0:
  153. diff: Union[str, None] = current_tool_call.get("arguments")
  154. if diff:
  155. diff = json.dumps(diff).replace(
  156. self.streamed_args_for_tool[self.current_tool_id],
  157. "",
  158. )
  159. delta = DeltaMessage(
  160. tool_calls=[
  161. DeltaToolCall(
  162. index=self.current_tool_id,
  163. function=DeltaFunctionCall(
  164. arguments=diff
  165. ).model_dump(exclude_none=True),
  166. )
  167. ]
  168. )
  169. self.streamed_args_for_tool[
  170. self.current_tool_id
  171. ] += diff
  172. else:
  173. delta = None
  174. else:
  175. delta = None
  176. # re-set stuff pertaining to progress in the current tool
  177. self.current_tool_id = len(tool_call_arr) - 1
  178. self.current_tool_name_sent = False
  179. self.streamed_args_for_tool.append("")
  180. logger.debug(f"starting on new tool {self.current_tool_id}")
  181. return delta
  182. # case: update an existing tool - this is handled below
  183. # if the current tool name hasn't been sent, send if available
  184. # - otherwise send nothing
  185. if not self.current_tool_name_sent:
  186. function_name = current_tool_call.get("name")
  187. if function_name:
  188. delta = DeltaMessage(tool_calls=[
  189. DeltaToolCall(index=self.current_tool_id,
  190. type="function",
  191. id=f"chatcmpl-tool-{random_uuid()}",
  192. function=DeltaFunctionCall(
  193. name=function_name).model_dump(
  194. exclude_none=True))
  195. ])
  196. self.current_tool_name_sent = True
  197. else:
  198. delta = None
  199. # now we know we're on the same tool call and we're streaming
  200. # arguments
  201. else:
  202. prev_arguments = self.prev_tool_call_arr[
  203. self.current_tool_id
  204. ].get("arguments")
  205. cur_arguments = current_tool_call.get("arguments")
  206. new_text = delta_text.replace("'", '"')
  207. if not cur_arguments and not prev_arguments:
  208. delta = None
  209. elif not cur_arguments and prev_arguments:
  210. logger.error(
  211. "INVARIANT - impossible to have arguments reset "
  212. "mid-arguments"
  213. )
  214. delta = None
  215. elif cur_arguments and not prev_arguments:
  216. cur_arguments_json = json.dumps(cur_arguments)
  217. logger.debug(
  218. f"finding {new_text} in {cur_arguments_json}"
  219. )
  220. arguments_delta = cur_arguments_json[
  221. : cur_arguments_json.index(new_text) + len(new_text)
  222. ]
  223. logger.debug(
  224. f"First tokens in arguments received: {arguments_delta}"
  225. )
  226. delta = DeltaMessage(
  227. tool_calls=[
  228. DeltaToolCall(
  229. index=self.current_tool_id,
  230. function=DeltaFunctionCall(
  231. arguments=arguments_delta
  232. ).model_dump(exclude_none=True),
  233. )
  234. ]
  235. )
  236. self.streamed_args_for_tool[
  237. self.current_tool_id
  238. ] += arguments_delta
  239. elif cur_arguments and prev_arguments:
  240. cur_args_json = json.dumps(cur_arguments)
  241. prev_args_json = json.dumps(prev_arguments)
  242. logger.debug(
  243. f"Searching for diff between \n{cur_args_json}\n"
  244. f"{prev_args_json}"
  245. )
  246. argument_diff = extract_intermediate_diff(
  247. cur_args_json, prev_args_json
  248. )
  249. logger.debug(f"got arguments diff: {argument_diff}")
  250. delta = DeltaMessage(
  251. tool_calls=[
  252. DeltaToolCall(
  253. index=self.current_tool_id,
  254. function=DeltaFunctionCall(
  255. arguments=argument_diff
  256. ).model_dump(exclude_none=True),
  257. )
  258. ]
  259. )
  260. self.streamed_args_for_tool[
  261. self.current_tool_id
  262. ] += argument_diff
  263. else:
  264. # try parsing it with regular JSON - if it works we're
  265. # at the end, and we need to send the difference between
  266. # tokens streamed so far and the valid JSON
  267. delta = None
  268. # check to see if the name is defined and has been sent. if so,
  269. # stream the name - otherwise keep waiting
  270. # finish by setting old and returning None as base case
  271. self.prev_tool_call_arr = tool_call_arr
  272. return delta
  273. except Exception as e:
  274. logger.error(f"Error trying to handle streaming tool call: {e}")
  275. logger.debug(
  276. "Skipping chunk as a result of tool streaming extraction "
  277. "error"
  278. )
  279. return None