# grammar.py

import collections
import functools
import weakref
from copy import copy, deepcopy
from dataclasses import dataclass, fields
from typing import List, Optional, Set, Union

import regex
import torch
from lark import Lark
from lark.lexer import Pattern, PatternRE, PatternStr, Token
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark.parsers.lalr_parser_state import ParserState
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


class FastParserState(ParserState):
    # class-level memo shared across instances: deepcopied stack values are
    # reused so repeated copies of parser states stay cheap
    copy_memo = {}

    def __copy__(self):
        new_value_stack = []
        for value in self.value_stack:
            key = f"{id(self)}_{id(value)}"
            if key not in self.copy_memo:
                self.copy_memo[key] = deepcopy(value, self.copy_memo)
            new_value_stack.append(self.copy_memo[key])

        new_instance = type(self)(
            self.parse_conf,
            self.lexer,
            copy(self.state_stack),
            new_value_stack,
        )

        self.copy_memo[id(self)] = new_instance
        return new_instance


class FastInteractiveParser(InteractiveParser):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.parser_state = FastParserState(
            self.parser_state.parse_conf,
            self.parser_state.lexer,
            self.parser_state.state_stack,
            self.parser_state.value_stack,
        )
        self.hash_val = None

    def __hash__(self):
        if self.hash_val is None:
            self.hash_val = hash(tuple(self.parser_state.state_stack))
        return self.hash_val

    def __copy__(self):
        return type(self)(
            self.parser,
            copy(self.parser_state),
            copy(self.lexer_thread),
        )


def get_pattern_validator(pattern: Pattern):
    """
    Accepts a pattern object, either lark.lexer.PatternStr or
    lark.lexer.PatternRE.

    Returns a function which validates a complete or partial string and
    returns a tuple with 2 values:
    - 0) the processed portion of the sequence (None if no match at all)
    - 1) None if the terminal isn't completed, "" if the terminal is completed
         with nothing left over, or the leftover remainder string
    """
    if isinstance(pattern, PatternRE):
        compiled_pattern = regex.compile(pattern.value)

        @functools.lru_cache(int(1e6))
        def get_re_matched_parts(seq):
            # match complete terminal, potentially with leftover seq
            complete_terminal_match = compiled_pattern.match(seq)
            if complete_terminal_match:
                spans = complete_terminal_match.spans()
                if spans:
                    span = spans[0]
                    if span[0] == 0:
                        processed_seq = seq[:span[1]]
                        remainder_seq = seq[span[1]:]
                        return processed_seq, remainder_seq

            # match doesn't complete terminal, but the sequence is fully
            # allowed
            partial_terminal_match = compiled_pattern.fullmatch(seq,
                                                                partial=True)
            if partial_terminal_match:
                return seq, None

            return None, None

        return get_re_matched_parts

    elif isinstance(pattern, PatternStr):
        base_str = pattern.value

        @functools.lru_cache(int(1e6))
        def get_str_matched_parts(seq):
            if seq.startswith(base_str):
                processed_seq = seq[:len(base_str)]
                remainder_seq = seq[len(base_str):]
                return processed_seq, remainder_seq
            elif base_str.startswith(seq):
                return seq, None
            else:
                return None, None

        return get_str_matched_parts

    else:
        raise TypeError(f"Invalid pattern type: {type(pattern)}")
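
# A minimal sketch of the validator contract above (the pattern value here is
# hypothetical, not drawn from any particular grammar):
#
#   validate = get_pattern_validator(PatternRE(r'"[a-z]*"'))
#   validate('"ab"x')   # -> ('"ab"', 'x')   terminal completed, 'x' left over
#   validate('"ab')     # -> ('"ab', None)   valid partial match, not complete
#   validate('x')       # -> (None, None)    no match at all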


def method_lru_cache(*lru_args, **lru_kwargs):
    # per-instance LRU cache holding only a weak reference to `self`
    # https://stackoverflow.com/a/44078118
    def decorator(func):
        @functools.wraps(func)
        def wrapped_func(self, *args, **kwargs):
            self_weak = weakref.ref(self)

            @functools.wraps(func)
            @functools.lru_cache(*lru_args, **lru_kwargs)
            def cached_method(*args, **kwargs):
                return func(self_weak(), *args, **kwargs)

            setattr(self, func.__name__, cached_method)
            return cached_method(*args, **kwargs)

        return wrapped_func

    return decorator


memoize_by_instance = method_lru_cache(int(1e7))
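
# A small sketch of how `memoize_by_instance` is intended to be used (the
# class below is illustrative only): results are cached per instance and per
# argument, and the weak reference keeps the cache from pinning `self`.
#
#   class Fibonacci:
#       @memoize_by_instance
#       def fib(self, n):
#           return n if n < 2 else self.fib(n - 1) + self.fib(n - 2)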


class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end_of_word = False
        self.value = None


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, key, value):
        node = self.root
        for char in key:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True
        node.value = value

    def get_best(self, word):
        """Return the longest stored prefix of `word`, the value stored at
        that prefix, and the unmatched remainder."""
        node = self.root
        prefix = ""
        best_prefix = ""
        best_value = node.value
        for char in word:
            if char in node.children:
                node = node.children[char]
                prefix += char
                if node.is_end_of_word:
                    best_prefix = prefix
                    best_value = node.value
            else:
                break  # break if char not in trie
        remainder = word[len(best_prefix):]
        return best_prefix, best_value, remainder
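
# Illustrative behaviour of `Trie.get_best` (keys and values are made up):
#
#   trie = Trie()
#   trie.insert("ab", 1)
#   trie.insert("abcd", 2)
#   trie.get_best("abcX")  # -> ("ab", 1, "cX")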


@dataclass
class IncrementalParserState:
    """
    Parsing utility which enforces uniqueness of
    - interactive parser state stack
    - incomplete `partial_token` string

    The state of the parser and the incomplete token comprise a unique parser
    state. The core function exposed is `self.step(new_seq)`, which returns a
    new IncrementalParserState with new_seq applied.

    The memoization strategy is
    - 1) ensure uniqueness of (interactive_parser, partial_token)
    - 2) cache class methods via `memoize_by_instance`, which considers
         id(self) and the function arguments
    """

    # unique state key
    interactive_parser: FastInteractiveParser

    # function of key
    terminal_candidates: list

    # shared across instances
    _ignored_terms: set
    _seq_validator: dict
    _memo: dict
    _full_seq_trie: Trie

    def __repr__(self):
        class_name = self.__class__.__name__
        state_stack = self.interactive_parser.parser_state.state_stack
        return f"{class_name}({state_stack})"

    @classmethod
    @functools.lru_cache(1000)
    def from_grammar(cls, grammar: str, start: str):
        lark_parser = Lark(
            grammar,
            regex=True,  # use `regex` not `re`
            start=start,
            parser="lalr",
            cache=True,  # results in 2-3x faster loading
        )
        base_interactive_parser = lark_parser.parse_interactive()
        interactive_parser = FastInteractiveParser(
            base_interactive_parser.parser,
            base_interactive_parser.parser_state,
            base_interactive_parser.lexer_thread)
        interactive_parser.lexer_thread.state.text = ""

        _seq_validator = {term.name: get_pattern_validator(term.pattern)
                          for term in lark_parser.terminals}
        _seq_validator["$END"] = lambda seq: tuple(
            ["" if seq is None else None] * 2)

        parser = cls(interactive_parser=interactive_parser,
                     terminal_candidates=None,
                     _ignored_terms=set(lark_parser.lexer_conf.ignore),
                     _seq_validator=_seq_validator,
                     _memo={},
                     _full_seq_trie=Trie())
        parser._full_seq_trie.insert("", parser)
        return parser

    def new(self, **kwargs):
        """Cached creation of a new state"""
        parser_state_key = hash(kwargs["interactive_parser"])
        if parser_state_key in self._memo:
            return self._memo[parser_state_key]
        instance_dict = {f.name: getattr(self, f.name) for f in fields(self)}
        instance_dict.update(kwargs)
        inst = self.__class__(**instance_dict)
        self._memo[parser_state_key] = inst
        return inst

    def __getitem__(self, full_seq):
        """Get the parser state, given a full sequence"""
        # pylint: disable=unused-variable
        match_seq, parser, remainder_seq = self._full_seq_trie.get_best(
            full_seq)
        if parser is None:
            return
        if remainder_seq:
            result = parser.step(remainder_seq)
            if result is None:
                return None
            remainder_seq, parser = result
            processed_seq = full_seq
            if remainder_seq:
                processed_seq = processed_seq[:-len(remainder_seq)]
            self._full_seq_trie.insert(processed_seq, parser)
        return remainder_seq, parser

    @memoize_by_instance
    def step(self, new_seq: str):
        """
        - Construct an extended (maybe-partial) token candidate
        - If it completes a terminal, create a new parser state with that
          terminal applied; if there is leftover from new_seq, recurse on the
          new parser
        - If it partially matches, return a new parser with the updated
          partial token string and updated terminal candidates
        - If there are no partial matches, return None
        """
        if new_seq == "":
            return "", self

        best_terminal, processed_seq, remainder_seq = (
            self.get_best_matched_terminal(new_seq))

        # invalid
        if best_terminal is None:
            return None
        # candidate doesn't complete terminal
        elif remainder_seq is None:
            return processed_seq, self
        # candidate completes terminal
        else:
            new_parser = self._next_with_new_terminal(best_terminal)
            if remainder_seq == "":
                return "", new_parser
            else:
                return new_parser.step(remainder_seq)

    @memoize_by_instance
    def _next_with_new_terminal(self, terminal):
        if terminal in self._ignored_terms:
            new_interactive_parser = self.interactive_parser
        else:
            new_interactive_parser = self.get_stepped_parser_state(terminal)

        return self.new(
            interactive_parser=new_interactive_parser,
            terminal_candidates=None,
        )

    def get_best_matched_terminal(self, seq):
        for terminal in self.accepts():
            processed_seq, remainder_seq = self._seq_validator[terminal](seq)
            if processed_seq:
                return terminal, processed_seq, remainder_seq
        return None, None, None

    @memoize_by_instance
    def get_stepped_parser_state(self, new_token_str):
        ip = copy(self.interactive_parser)
        ip.feed_token(Token(new_token_str, ""))
        return ip

    @memoize_by_instance
    def accepts(self):
        return set(self.interactive_parser.accepts()) | self._ignored_terms

    @memoize_by_instance
    def allowed_terminals(self):
        if self.terminal_candidates is not None:
            return tuple(sorted(self.terminal_candidates))
        return tuple(sorted(self.accepts()))

    @memoize_by_instance
    def is_valid_next_seq(self, new_seq: Optional[str]):
        if new_seq is None:
            return "$END" in self.allowed_terminals()
        return self.step(new_seq) is not None
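
# A minimal usage sketch for IncrementalParserState (the toy grammar is
# illustrative, not part of this module):
#
#   parser = IncrementalParserState.from_grammar(
#       'start: "hello" WS "world"\n%import common.WS', "start")
#   remainder, state = parser["hello wo"]        # longest-prefix lookup + step
#   state.is_valid_next_seq(remainder + "rld")   # True: completes "world"
#   state.is_valid_next_seq(remainder + "xyz")   # False: no terminal matches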


class TokenVocab:
    """
    Normalized token vocabulary accounting for whitespace and multiple IDs
    per token
    - iter: iterate over normalized token strings
    - vocab[token_str]: return the token id set
    """

    def __init__(self,
                 tokenizer: Union[PreTrainedTokenizer,
                                  PreTrainedTokenizerFast],
                 legal_chars: Optional[Set[str]] = None):
        self.norm_vocab = collections.defaultdict(set)
        for token_id in tokenizer.vocab.values():
            if token_id == tokenizer.eos_token_id:
                self.norm_vocab[None].add(token_id)
                continue
            # decode with a BOS prefix so leading whitespace survives, then
            # strip the BOS text to recover the normalized token string
            bos_len = len(tokenizer.bos_token)
            norm_token = tokenizer.decode([tokenizer.bos_token_id,
                                           token_id])[bos_len:]
            if legal_chars is None or all(
                    char in legal_chars for char in norm_token):
                self.norm_vocab[norm_token].add(token_id)

    def __iter__(self):
        return iter(self.norm_vocab)

    def __getitem__(self, tok_str):
        return self.norm_vocab[tok_str]
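
# A short TokenVocab sketch (assumes a locally available Hugging Face
# tokenizer; "gpt2" is just an illustrative choice):
#
#   from transformers import AutoTokenizer
#   vocab = TokenVocab(AutoTokenizer.from_pretrained("gpt2"))
#   for tok_str in vocab:      # normalized token strings; None stands for EOS
#       ids = vocab[tok_str]   # set of token ids that decode to tok_str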


class NextTokenValidator:
    """
    Given a grammar and a tokenizer, yields the token strings / token ids
    which are legal continuations of a full sequence.
    """

    def __init__(
            self,
            tokenizer,
            grammar: str,
            grammar_start: str = "start",
            legal_chars: Optional[Set[str]] = None,
    ):
        self.tokenizer = tokenizer
        self.vocab = TokenVocab(tokenizer, legal_chars=legal_chars)
        self.root_parser = IncrementalParserState.from_grammar(
            grammar, grammar_start)

    def get_valid_next_token_strs(self, full_seq):
        """
        Generate valid token strings given the full sequence
        """
        result = self.root_parser[full_seq]
        if result is None:
            return
        partial_term, parser = result
        for token in self.vocab:
            if token is None:
                if partial_term == "" and parser.is_valid_next_seq(token):
                    yield None
            else:
                if parser.is_valid_next_seq(partial_term + token):
                    yield token

    def get_valid_next_token_ids(self, full_seq):
        """
        Generate valid token ids given the full sequence
        """
        for tok_str in self.get_valid_next_token_strs(full_seq):
            yield from self.vocab[tok_str]
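
# Illustrative use of NextTokenValidator (the tokenizer and grammar are
# placeholders):
#
#   validator = NextTokenValidator(tokenizer, grammar='start: "ab" "cd"')
#   valid_ids = set(validator.get_valid_next_token_ids("ab"))
#   # ids of every vocab token t where "ab" + t still parses, e.g. "c", "cd"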


class GrammarLogitsProcessor(NextTokenValidator):
    """
    Apply NextTokenValidator in __call__ and set excluded tokens' logits
    to -inf
    """

    def __call__(self, logits: torch.Tensor,
                 token_ids: List[List[int]]) -> None:
        for i in range(len(token_ids)):
            # get valid token IDs given prior tokens
            sequence = self.tokenizer.decode(token_ids[i])
            valid_token_ids = self.get_valid_next_token_ids(sequence)
            valid = torch.tensor(list(valid_token_ids), dtype=torch.long)

            # modify logits given valid token IDs
            N = len(logits[i])
            mask = torch.zeros(N, dtype=torch.bool)
            mask[valid] = True
            logits[i][~mask] = float("-inf")
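
# A minimal sketch of applying the processor to a batch of logits (tokenizer,
# grammar, and tensor shapes are illustrative; in practice the serving engine
# calls this once per decoding step):
#
#   processor = GrammarLogitsProcessor(tokenizer, grammar='start: "ab" "cd"')
#   processor(logits, token_ids)  # in-place: disallowed token logits -> -inf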