import collections
from copy import deepcopy, copy
from dataclasses import dataclass, fields
import functools
import regex
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing import Optional, List, Set, Union
import weakref

from lark import Lark
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark.parsers.lalr_parser_state import ParserState
from lark.lexer import Token, Pattern, PatternStr, PatternRE


class FastParserState(ParserState):
    """
    ParserState whose ``__copy__`` reuses previously made deep copies of
    value-stack entries instead of deep-copying them on every copy.
    """
    # Class-level memo shared by every instance. It is handed to `deepcopy`
    # as its memo dict and is also written to directly below.
    # NOTE(review): nothing in this file ever evicts entries from it, so it
    # grows for the lifetime of the process - confirm this is acceptable.
    copy_memo = {}

    def __copy__(self):
        """Return a new state with a copied state stack and memoized
        deep copies of the value-stack entries."""
        new_value_stack = []
        for value in self.value_stack:
            # one cached deep copy per (this state, value) pair
            key = f"{id(self)}_{id(value)}"
            if key not in self.copy_memo:
                self.copy_memo[key] = deepcopy(value, self.copy_memo)
            new_value_stack.append(self.copy_memo[key])

        new_instance = type(self)(
            self.parse_conf,
            self.lexer,
            copy(self.state_stack),
            new_value_stack,
        )

        self.copy_memo[id(self)] = new_instance
        return new_instance


class FastInteractiveParser(InteractiveParser):
    """
    InteractiveParser that swaps its parser state for a `FastParserState`
    and is hashable by its parser state stack.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Re-wrap the state created by the base class so that copies go
        # through FastParserState.__copy__. The stacks are passed through
        # as-is (not copied) here.
        self.parser_state = FastParserState(
            self.parser_state.parse_conf,
            self.parser_state.lexer,
            self.parser_state.state_stack,
            self.parser_state.value_stack,
        )
        self.hash_val = None

    def __hash__(self):
        # Computed lazily and then cached: later mutations of the state
        # stack (e.g. via feed_token) do NOT update the cached hash.
        if self.hash_val is None:
            self.hash_val = hash(tuple(self.parser_state.state_stack))
        return self.hash_val

    def __copy__(self):
        return type(self)(
            self.parser,
            copy(self.parser_state),
            copy(self.lexer_thread),
        )


def get_pattern_validator(pattern: Pattern):
    """
    Build a cached validator for a single terminal pattern.

    Accepts a pattern object, either lark.lexer.PatternStr or
    lark.lexer.PatternRE, and returns a function which validates a
    complete or partial string.

    The returned function takes a string and returns a 2-tuple:
    - 0) The processed portion of the sequence (None if no match at all)
    - 1) None if the sequence doesn't complete the terminal,
         "" if it completes the terminal with no remainder,
         or the unconsumed remainder string otherwise

    Raises TypeError for any other pattern type.
    """
    if isinstance(pattern, PatternRE):
        compiled_pattern = regex.compile(pattern.value)

        @functools.lru_cache(int(1e6))
        def get_re_matched_parts(seq):
            # match complete terminal, potentially with leftover seq
            complete_terminal_match = compiled_pattern.match(seq)
            if complete_terminal_match:
                spans = complete_terminal_match.spans()
                if spans:
                    span = spans[0]
                    if span[0] == 0:
                        processed_seq = seq[:span[1]]
                        remainder_seq = seq[span[1]:]
                        return processed_seq, remainder_seq

            # match doesn't complete terminal, but the sequence is fully
            # allowed (partial matching is a `regex` module feature)
            partial_terminal_match = compiled_pattern.fullmatch(seq,
                                                                partial=True)
            if partial_terminal_match:
                return seq, None

            return None, None

        return get_re_matched_parts

    elif isinstance(pattern, PatternStr):
        base_str = pattern.value

        @functools.lru_cache(int(1e6))
        def get_str_matched_parts(seq):
            if seq.startswith(base_str):
                # literal fully consumed, anything after it is remainder
                processed_seq = seq[:len(base_str)]
                remainder_seq = seq[len(base_str):]
                return processed_seq, remainder_seq
            elif base_str.startswith(seq):
                # seq is a proper prefix of the literal
                return seq, None
            else:
                return None, None

        return get_str_matched_parts

    else:
        raise TypeError(f"Invalid pattern type: {type(pattern)}")


def method_lru_cache(*lru_args, **lru_kwargs):
    """
    Decorator factory giving each instance its own `functools.lru_cache`
    for the decorated method.

    On the first call through an instance, a cached function closing over
    a weak reference to that instance is created and stored on the
    instance under the method's name, so later calls on that instance hit
    the cached function directly. `self` is never part of the cache key.

    Adapted from https://stackoverflow.com/a/44078118
    """

    def decorator(func):

        @functools.wraps(func)
        def wrapped_func(self, *args, **kwargs):
            # weak reference so the cache does not keep the instance alive
            self_weak = weakref.ref(self)

            @functools.wraps(func)
            @functools.lru_cache(*lru_args, **lru_kwargs)
            def cached_method(*args, **kwargs):
                return func(self_weak(), *args, **kwargs)

            # shadow the class-level method on this instance
            setattr(self, func.__name__, cached_method)
            return cached_method(*args, **kwargs)

        return wrapped_func

    return decorator


memoize_by_instance = method_lru_cache(int(1e7))


class TrieNode:
    """Single trie node: child map, end-of-key flag and stored value."""

    def __init__(self):
        self.children = {}
        self.is_end_of_word = False
        self.value = None


class Trie:
    """Character trie mapping string keys to values, with longest-prefix
    lookup."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, key, value):
        """Store `value` under `key`, overwriting any existing value."""
        node = self.root
        for char in key:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True
        node.value = value

    def get_best(self, word):
        """
        Find the longest inserted key which is a prefix of `word`.

        Returns (best_prefix, best_value, remainder) where `remainder` is
        the part of `word` after `best_prefix`. If no non-empty key
        matches, `best_prefix` is empty and `best_value` is the root's
        value (None unless the empty key was inserted).
        """
        node = self.root
        best_len = 0
        best_value = node.value
        # Track only the matched length; slicing once at the end avoids
        # rebuilding the prefix string character by character.
        for depth, char in enumerate(word, start=1):
            node = node.children.get(char)
            if node is None:
                break  # break if char not in trie
            if node.is_end_of_word:
                best_len = depth
                best_value = node.value

        return word[:best_len], best_value, word[best_len:]


@dataclass
class IncrementalParserState:
    """
    Incremental parsing utility built around a FastInteractiveParser.

    Instances created through `self.new(...)` are unique per interactive
    parser state stack (see `_memo`).

    Core function exposed is `self.step(new_seq)`
    - Returns `(partial_token_str, parser_state)` with new_seq applied,
      or None if new_seq is not valid from this state

    Memoization strategy is
    - 1) Ensure uniqueness of instances per parser state stack
    - 2) Cache methods per instance via `memoize_by_instance`
    """

    # unique state key
    interactive_parser: FastInteractiveParser

    # function of key
    terminal_candidates: list

    # shared across instances
    _ignored_terms: set
    _seq_validator: dict
    _memo: dict
    _full_seq_trie: Trie

    def __repr__(self):
        class_name = self.__class__.__name__
        state_stack = self.interactive_parser.parser_state.state_stack
        return f"{class_name}({state_stack})"

    @classmethod
    @functools.lru_cache(1000)
    def from_grammar(cls, grammar: str, start: str):
        """
        Build (and cache per (grammar, start)) the root parser state for a
        lark grammar string.
        """
        lark_parser = Lark(
            grammar,
            regex=True,  # use `regex` not `re`
            start=start,
            parser="lalr",
            cache=True,  # results in 2-3x faster loading
        )
        base_interactive_parser = lark_parser.parse_interactive()
        interactive_parser = FastInteractiveParser(
            base_interactive_parser.parser,
            base_interactive_parser.parser_state,
            base_interactive_parser.lexer_thread)
        interactive_parser.lexer_thread.state.text = ""

        # one validator per terminal name
        _seq_validator = {(term.name): get_pattern_validator(term.pattern)
                          for term in lark_parser.terminals}
        # "$END" only matches the `None` (end-of-sequence) pseudo-sequence
        _seq_validator["$END"] = lambda seq: tuple(
            ["" if seq is None else None] * 2)

        parser = cls(interactive_parser=interactive_parser,
                     terminal_candidates=None,
                     _ignored_terms=set(lark_parser.lexer_conf.ignore),
                     _seq_validator=_seq_validator,
                     _memo={},
                     _full_seq_trie=Trie())
        # the empty sequence always resolves to the root state
        parser._full_seq_trie.insert("", parser)
        return parser

    def new(self, **kwargs):
        """
        Cached creation of a new state.

        Copies this instance's fields, overridden by `kwargs` (which must
        contain `interactive_parser`). If a state with the same parser
        state stack was created before, that instance is returned instead.
        """
        # Key on the state stack itself rather than on its hash, so two
        # different stacks whose hashes collide can never alias each other.
        parser_state_key = tuple(
            kwargs["interactive_parser"].parser_state.state_stack)
        if parser_state_key in self._memo:
            return self._memo[parser_state_key]

        instance_dict = {f.name: getattr(self, f.name) for f in fields(self)}
        instance_dict.update(kwargs)
        inst = self.__class__(**instance_dict)

        self._memo[parser_state_key] = inst
        return inst

    def __getitem__(self, full_seq):
        """
        Get the parser state, given a full sequence.

        Returns `(remainder_seq, parser)` where `remainder_seq` is the
        trailing partial-terminal text not yet consumed by `parser`, or
        None if the sequence is invalid.
        """
        # start from the longest already-resolved prefix of full_seq
        # pylint: disable=unused-variable
        match_seq, parser, remainder_seq = self._full_seq_trie.get_best(
            full_seq)
        if parser is None:
            return None
        if remainder_seq:
            result = parser.step(remainder_seq)
            if result is None:
                return None
            remainder_seq, parser = result
            # remember the state reached after the fully processed part
            processed_seq = full_seq
            if remainder_seq:
                processed_seq = processed_seq[:-len(remainder_seq)]
            self._full_seq_trie.insert(processed_seq, parser)
        return remainder_seq, parser

    @memoize_by_instance
    def step(self, new_seq: str):
        """
        Apply `new_seq` to this state.

        - Find the first accepted terminal matching new_seq
        - If it completes the terminal, advance to the next parser state
          - if there is leftover from new_seq, recurse on the new parser
        - If it only partially matches, return `(new_seq, self)`
        - If nothing matches, return None
        """
        if new_seq == "":
            return "", self

        best_terminal, processed_seq, remainder_seq = (
            self.get_best_matched_terminal(new_seq))

        # invalid
        if best_terminal is None:
            return None

        # candidate doesn't complete terminal
        elif remainder_seq is None:
            return processed_seq, self

        # candidate completes terminal
        else:
            new_parser = self._next_with_new_terminal(best_terminal)
            if remainder_seq == "":
                return "", new_parser
            else:
                return new_parser.step(remainder_seq)

    @memoize_by_instance
    def _next_with_new_terminal(self, terminal):
        """Return the state reached after consuming `terminal`."""
        if terminal in self._ignored_terms:
            # ignored terminals do not advance the LALR parser
            new_interactive_parser = self.interactive_parser
        else:
            new_interactive_parser = self.get_stepped_parser_state(terminal)

        return self.new(
            interactive_parser=new_interactive_parser,
            terminal_candidates=None,
        )

    @memoize_by_instance
    def _ordered_accepts(self):
        """Accepted terminal names in sorted (deterministic) order."""
        return tuple(sorted(self.accepts()))

    def get_best_matched_terminal(self, seq):
        """
        Return `(terminal, processed_seq, remainder_seq)` for the first
        accepted terminal whose validator consumes a non-empty part of
        `seq`, or `(None, None, None)` if there is none.
        """
        # Iterate in sorted order: iterating the raw set would make the
        # chosen terminal depend on string-hash ordering between runs.
        for terminal in self._ordered_accepts():
            processed_seq, remainder_seq = self._seq_validator[terminal](seq)
            if processed_seq:
                return terminal, processed_seq, remainder_seq

        return None, None, None

    @memoize_by_instance
    def get_stepped_parser_state(self, new_token_str):
        """Copy the interactive parser and feed it one terminal."""
        ip = copy(self.interactive_parser)
        # Token is built with the terminal name and an empty value
        ip.feed_token(Token(new_token_str, ""))
        return ip

    @memoize_by_instance
    def accepts(self):
        """Terminals accepted by the parser, plus all ignored terminals."""
        return set(self.interactive_parser.accepts()) | self._ignored_terms

    @memoize_by_instance
    def allowed_terminals(self):
        """Sorted tuple of `terminal_candidates` if set, otherwise of
        `accepts()`."""
        if self.terminal_candidates is not None:
            return tuple(sorted(self.terminal_candidates))
        return tuple(sorted(self.accepts()))

    @memoize_by_instance
    def is_valid_next_seq(self, new_seq: Optional[str]):
        """
        Whether `new_seq` may follow from this state. `None` stands for
        end-of-sequence and is valid iff "$END" is an allowed terminal.
        """
        if new_seq is None:
            return "$END" in self.allowed_terminals()
        return self.step(new_seq) is not None


class TokenVocab:
    """
    Normalized token vocabulary accounting for whitespace and multiple IDs
    per token
    - iter: iterate over normalized token strings
    - vocab[token_str]: return token id set

    The EOS token id is stored under the key `None`.
    """

    def __init__(self,
                 tokenizer: Union[PreTrainedTokenizer,
                                  PreTrainedTokenizerFast],
                 legal_chars: Optional[Set[str]] = None):

        self.norm_vocab = collections.defaultdict(set)

        # loop-invariant, so computed once
        bos_len = len(tokenizer.bos_token)

        for token_id in tokenizer.vocab.values():
            if token_id == tokenizer.eos_token_id:
                self.norm_vocab[None].add(token_id)
                continue
            # Decode behind a BOS token and strip the BOS text again.
            # NOTE(review): presumably this keeps leading whitespace that
            # decoding the token on its own would drop - confirm.
            norm_token = tokenizer.decode([tokenizer.bos_token_id,
                                           token_id])[bos_len:]
            if legal_chars is None or all(char in legal_chars
                                          for char in norm_token):
                self.norm_vocab[norm_token].add(token_id)

    def __iter__(self):
        return iter(self.norm_vocab)

    def __getitem__(self, tok_str):
        return self.norm_vocab[tok_str]


class NextTokenValidator:
    """
    Given a tokenizer and a lark grammar, yields the tokens which may
    validly follow a given text sequence.
    """

    def __init__(
        self,
        tokenizer,
        grammar: str,
        grammar_start: str = "start",
        legal_chars: Optional[Set[str]] = None,
    ):
        self.tokenizer = tokenizer
        self.vocab = TokenVocab(tokenizer, legal_chars=legal_chars)
        self.root_parser = IncrementalParserState.from_grammar(
            grammar, grammar_start)

    def get_valid_next_token_strs(self, full_seq):
        """
        Generate valid token strings given the full sequence.

        Yields nothing if `full_seq` itself is invalid. Yields `None` for
        the EOS entry when the sequence may end here.
        """
        result = self.root_parser[full_seq]
        if result is None:
            return
        partial_term, parser = result
        for token in self.vocab:
            if token is None:
                # EOS is only valid when no terminal is left half-finished
                if partial_term == "" and parser.is_valid_next_seq(token):
                    yield None
            else:
                if parser.is_valid_next_seq(partial_term + token):
                    yield token

    def get_valid_next_token_ids(self, full_seq):
        """
        Generate valid token ids given the full sequence
        """
        for tok_str in self.get_valid_next_token_strs(full_seq):
            yield from self.vocab[tok_str]


class GrammarLogitsProcessor(NextTokenValidator):
    """
    Apply NextTokenValidator in __call__ and set excluded tokens logits
    to -inf
    """

    def __call__(self, logits: torch.Tensor,
                 token_ids: List[List[int]]) -> None:
        """
        Mask `logits` in place, one row per entry of `token_ids`.

        For row i, every position whose token id is not a valid
        continuation of the decoded `token_ids[i]` is set to -inf.
        """
        for i, prior_ids in enumerate(token_ids):
            # get valid token IDs given prior tokens
            sequence = self.tokenizer.decode(prior_ids)
            valid_token_ids = self.get_valid_next_token_ids(sequence)
            # build index and mask on the same device as the logits so
            # the masking below does not index across devices
            valid = torch.tensor(list(valid_token_ids),
                                 dtype=torch.long,
                                 device=logits.device)

            # modify logits given valid token IDs
            N = len(logits[i])
            mask = torch.zeros(N, dtype=torch.bool, device=logits.device)
            mask[valid] = True
            logits[i][~mask] = float("-inf")