# hermes_tool_parser.py

import json
import re
from typing import Dict, List, Sequence, Union

import partial_json_parser
from loguru import logger
from partial_json_parser.core.options import Allow

from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.openai.protocol import (DeltaFunctionCall,
                                                 DeltaMessage, DeltaToolCall,
                                                 ExtractedToolCallInformation,
                                                 FunctionCall, ToolCall)
from aphrodite.endpoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser)
from aphrodite.endpoints.openai.tool_parsers.utils import (
    extract_intermediate_diff)
from aphrodite.transformers_utils.tokenizer import (AnyTokenizer,
                                                    MistralTokenizer)


class Hermes2ProToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        if isinstance(self.model_tokenizer, MistralTokenizer):
            logger.error(
                "Detected Mistral tokenizer when using a Hermes model")
            self.model_tokenizer = self.model_tokenizer.tokenizer

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        # map what has been streamed for each tool so far to a list
        self.streamed_args_for_tool: List[str] = []

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
        self.scratch_pad_regex = re.compile(
            r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
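
        # Note (illustrative, not from the original source): the first
        # alternative of tool_call_regex captures a complete, closed call,
        # e.g. '<tool_call>{"name": "f", "arguments": {}}</tool_call>',
        # while the second captures a call that has been opened but not yet
        # closed at end-of-string, e.g. '<tool_call>{"name": "f", "argu'.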

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")

        # use .get() so a missing token yields None (and triggers the
        # RuntimeError below) instead of raising a bare KeyError
        self.tool_call_start_token_id = self.model_tokenizer.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.model_tokenizer.vocab.get(
            self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "Hermes 2 Pro Tool parser could not locate tool call "
                "start/end tokens in the tokenizer!")

    def extract_tool_calls(
            self, model_output: str) -> ExtractedToolCallInformation:

        # sanity check; avoid unnecessary processing
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        else:

            try:
                # there are two possible captures - between tags, or between a
                # tag and end-of-string, so the result of findall is an array
                # of tuples where one element is a function call and the other
                # is an empty string
                function_call_tuples = self.tool_call_regex.findall(
                    model_output)

                # load the JSON, and then use it to build the Function and
                # Tool Call
                raw_function_calls = [
                    json.loads(match[0] if match[0] else match[1])
                    for match in function_call_tuples
                ]
                tool_calls = [
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # function call args are JSON but as a string
                            arguments=json.dumps(function_call["arguments"])))
                    for function_call in raw_function_calls
                ]

                content = model_output[:model_output.find(
                    self.tool_call_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None)

            except Exception as e:
                logger.error(
                    f"Error in extracting tool call from response: {e}")
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)
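
    # Illustrative example (not from the original source): given
    #   model_output = ('I will check the weather. '
    #                   '<tool_call>{"name": "get_weather", '
    #                   '"arguments": {"city": "Paris"}}</tool_call>')
    # extract_tool_calls returns tools_called=True, content set to the text
    # before the tag, and a single ToolCall whose function name is
    # "get_weather" and whose arguments are the JSON string
    # '{"city": "Paris"}'.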

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:

        logger.debug(f"delta_text: {delta_text}")
        logger.debug(f"delta_token_ids: {delta_token_ids}")

        # check to see if we should be streaming a tool call - is there a
        # tool call start token in the text generated so far?
        if self.tool_call_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)

        try:
            # figure out where we are in the parsing by counting tool call
            # start & end tags
            prev_tool_start_count = previous_token_ids.count(
                self.tool_call_start_token_id)
            prev_tool_end_count = previous_token_ids.count(
                self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(
                self.tool_call_start_token_id)
            cur_tool_end_count = current_token_ids.count(
                self.tool_call_end_token_id)

            # case: if we're generating text, OR rounding out a tool call
            if (cur_tool_start_count == cur_tool_end_count
                    and prev_tool_end_count == cur_tool_end_count):
                logger.debug("Generating text content! skipping tool parsing.")
                if delta_text != self.tool_call_end_token:
                    return DeltaMessage(content=delta_text)

            # case: if tool open & close tag counts don't match, we're doing
            # something with tools with this diff (an imaginary "else" block
            # for the condition above).
            # flags for partial JSON parsing. exported constants from
            # "Allow" are handled via BIT MASK
            flags = Allow.ALL if self.current_tool_name_sent \
                else Allow.ALL & ~Allow.STR
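            # (illustrative note, not from the original source) masking out
            # Allow.STR means a trailing, unterminated string value is not
            # auto-completed by the partial parser, so a half-generated
            # function name is withheld until the full name has arrived.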

            # case -- we're starting a new tool call
            if (cur_tool_start_count > cur_tool_end_count
                    and cur_tool_start_count > prev_tool_start_count):
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(
                        self.tool_call_start_token)[-1]
                else:
                    tool_call_portion = None
                    delta = None

                text_portion = None

                # set cursors and state appropriately
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug(f"Starting on a new tool {self.current_tool_id}")

            # case -- we're updating an existing tool call
            elif (cur_tool_start_count > cur_tool_end_count
                  and cur_tool_start_count == prev_tool_start_count):

                # get the portion of the text that's the tool call
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
                text_portion = None

            # case -- the current tool call is being closed.
            elif (cur_tool_start_count == cur_tool_end_count
                  and cur_tool_end_count > prev_tool_end_count):
                diff = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments")
                if diff:
                    diff = json.dumps(diff).replace(
                        self.streamed_args_for_tool[self.current_tool_id], "")
                    logger.debug("Finishing tool and found diff that had not "
                                 f"been streamed yet: {diff}")
                    self.streamed_args_for_tool[self.current_tool_id] += diff
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=diff).model_dump(
                                              exclude_none=True))
                    ])

            # case -- otherwise we're just generating text
            else:
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                delta = DeltaMessage(tool_calls=[], content=text)
                return delta
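
            # At this point tool_call_portion is the (possibly incomplete)
            # JSON after the last <tool_call> tag, e.g. '{"name": "get_wea'
            # (illustrative value, not from the original source), or None if
            # only the start token itself has arrived so far.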

            try:
                current_tool_call = partial_json_parser.loads(
                    tool_call_portion or "{}",
                    flags) if tool_call_portion else None
                logger.debug(f"Parsed tool call {current_tool_call}")
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # case - we haven't sent the tool name yet. If it's available,
            # send it. otherwise, wait until it's available.
            if not self.current_tool_name_sent:
                function_name: Union[str, None] = (
                    current_tool_call.get("name")
                    if current_tool_call else None)
                if function_name:
                    self.current_tool_name_sent = True
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                else:
                    return None

            # case -- otherwise, send the tool call delta

            # if the tool call portion is None, send the delta as text
            if tool_call_portion is None:
                # if there's text but not tool calls, send that -
                # otherwise None to skip chunk
                delta = (DeltaMessage(content=delta_text)
                         if text_portion is not None else None)
                return delta

            # now, the nitty-gritty of tool calls
            # now we have the portion to parse as tool call.

            logger.debug("Trying to parse current tool call with ID "
                         f"{self.current_tool_id}")

            # if we're starting a new tool call, push an empty object in as
            # a placeholder for the arguments
            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            # main logic for tool parsing here - compare prev. partially-
            # parsed JSON to the current partially-parsed JSON
            prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                "arguments")
            cur_arguments = current_tool_call.get("arguments")

            logger.debug(f"diffing old arguments: {prev_arguments}")
            logger.debug(f"against new ones: {cur_arguments}")

            # case -- no arguments have been created yet. skip sending a delta.
            if not cur_arguments and not prev_arguments:
                logger.debug(f"Skipping text {delta_text} - no arguments")
                delta = None

            # case -- prev arguments are defined, but none are now.
            # probably impossible, but not a fatal error - just keep going
            elif not cur_arguments and prev_arguments:
                logger.error("should be impossible to have arguments reset "
                             "mid-call. skipping streaming anything.")
                delta = None

            # case -- we now have the first info about arguments available
            # from autocompleting the JSON
            elif cur_arguments and not prev_arguments:

                cur_arguments_json = json.dumps(cur_arguments)
                logger.debug(f"finding {delta_text} in {cur_arguments_json}")

                # find where the latest delta_text ends inside the serialized
                # arguments, and stream everything up to that point
                args_delta_start_loc = cur_arguments_json.index(delta_text) \
                    + len(delta_text)

                # use that to find the actual delta
                arguments_delta = cur_arguments_json[:args_delta_start_loc]
                logger.debug(
                    f"First tokens in arguments received: {arguments_delta}")

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=arguments_delta).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += arguments_delta
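                # e.g. (illustrative values, not from the original source):
                # with cur_arguments_json == '{"city": "Paris"}' and
                # delta_text == '{"', args_delta_start_loc is 2 and
                # arguments_delta is '{"', i.e. everything up to and
                # including the newest text gets streamed.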

            # last case -- we have an update to existing arguments.
            elif cur_arguments and prev_arguments:

                cur_args_json = json.dumps(cur_arguments)
                prev_args_json = json.dumps(prev_arguments)
                logger.debug(f"Searching for diff between\n{cur_args_json}")
                logger.debug(f"and\n{prev_args_json}")

                argument_diff = extract_intermediate_diff(
                    cur_args_json, prev_args_json)
                logger.debug(f"got argument diff {argument_diff}")

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=argument_diff).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += argument_diff
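                # e.g. (illustrative values, not from the original source):
                # if the previous partial parse serialized to
                # '{"city": "Par"}' and the current one to
                # '{"city": "Paris"}', extract_intermediate_diff returns
                # 'is', so only the newly generated characters are streamed.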

            # handle saving the state for the current tool into
            # the "prev" list for use in diffing for the next iteration
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[self.current_tool_id] = \
                    current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception as e:
            logger.error(f"Error trying to handle streaming tool call: {e}")
            return None  # do not stream a delta. skip this token ID.
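

# Minimal construction sketch (illustrative only; the checkpoint name is an
# assumption and not part of the original source):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained(
#       "NousResearch/Hermes-2-Pro-Llama-3-8B")  # vocab contains <tool_call>
#   parser = Hermes2ProToolParser(tokenizer)
#
# During streaming, the serving layer is expected to call
# extract_tool_calls_streaming() once per generated chunk, passing the
# cumulative previous/current text and token ids plus the new delta; the
# parser keeps per-tool state (current_tool_id, streamed_args_for_tool) so
# each returned DeltaMessage carries only the not-yet-streamed portion of the
# tool name and arguments.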