import collections
import functools
import weakref
from copy import copy, deepcopy
from dataclasses import dataclass, fields
from typing import List, Optional, Set, Union

import regex
import torch
from lark import Lark
from lark.lexer import Pattern, PatternRE, PatternStr, Token
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark.parsers.lalr_parser_state import ParserState
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


class FastParserState(ParserState):
    # Memoize deepcopies of value-stack entries (shared at class level) so
    # that repeatedly copying the interactive parser stays cheap.
    copy_memo = {}

    def __copy__(self):
        new_value_stack = []
        for value in self.value_stack:
            key = f"{id(self)}_{id(value)}"
            if key not in self.copy_memo:
                self.copy_memo[key] = deepcopy(value, self.copy_memo)
            new_value_stack.append(self.copy_memo[key])

        new_instance = type(self)(
            self.parse_conf,
            self.lexer,
            copy(self.state_stack),
            new_value_stack,
        )

        self.copy_memo[id(self)] = new_instance
        return new_instance


class FastInteractiveParser(InteractiveParser):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.parser_state = FastParserState(
            self.parser_state.parse_conf,
            self.parser_state.lexer,
            self.parser_state.state_stack,
            self.parser_state.value_stack,
        )
        self.hash_val = None

    def __hash__(self):
        if self.hash_val is None:
            self.hash_val = hash(tuple(self.parser_state.state_stack))
        return self.hash_val

    def __copy__(self):
        return type(self)(
            self.parser,
            copy(self.parser_state),
            copy(self.lexer_thread),
        )


def get_pattern_validator(pattern: Pattern):
    """
    Accepts a pattern object, either lark.lexer.PatternStr or
    lark.lexer.PatternRE.

    Returns a function which validates a complete or partial string against
    that pattern. The returned function produces a tuple with 2 values:
    - 0) the processed portion of the sequence (None if there is no match at
      all)
    - 1) the remainder: None if the sequence doesn't complete the terminal,
      "" if it completes the terminal with nothing left over, or the leftover
      string if it completes the terminal with characters to spare
    """
    if isinstance(pattern, PatternRE):
        compiled_pattern = regex.compile(pattern.value)

        @functools.lru_cache(int(1e6))
        def get_re_matched_parts(seq):
            # match complete terminal, potentially with leftover seq
            complete_terminal_match = compiled_pattern.match(seq)
            if complete_terminal_match:
                spans = complete_terminal_match.spans()
                if spans:
                    span = spans[0]
                    if span[0] == 0:
                        processed_seq = seq[:span[1]]
                        remainder_seq = seq[span[1]:]
                        return processed_seq, remainder_seq

            # match doesn't complete terminal, but the sequence is fully
            # allowed
            partial_terminal_match = compiled_pattern.fullmatch(seq,
                                                                partial=True)
            if partial_terminal_match:
                return seq, None

            return None, None

        return get_re_matched_parts

    elif isinstance(pattern, PatternStr):
        base_str = pattern.value

        @functools.lru_cache(int(1e6))
        def get_str_matched_parts(seq):
            if seq.startswith(base_str):
                processed_seq = seq[:len(base_str)]
                remainder_seq = seq[len(base_str):]
                return processed_seq, remainder_seq
            elif base_str.startswith(seq):
                return seq, None
            else:
                return None, None

        return get_str_matched_parts

    else:
        raise TypeError(f"Invalid pattern type: {type(pattern)}")
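

# Illustrative sketch, not part of the original module: a quick check of the
# validator contract described in get_pattern_validator's docstring, using a
# simple PatternStr terminal ("true" is an arbitrary example value).
def _demo_pattern_validator():
    validator = get_pattern_validator(PatternStr("true"))
    assert validator("true,") == ("true", ",")  # completes terminal, "," left over
    assert validator("tr") == ("tr", None)      # valid prefix, terminal incomplete
    assert validator("x") == (None, None)       # no match at all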


def method_lru_cache(*lru_args, **lru_kwargs):
    # https://stackoverflow.com/a/44078118
    def decorator(func):

        @functools.wraps(func)
        def wrapped_func(self, *args, **kwargs):
            self_weak = weakref.ref(self)

            @functools.wraps(func)
            @functools.lru_cache(*lru_args, **lru_kwargs)
            def cached_method(*args, **kwargs):
                return func(self_weak(), *args, **kwargs)

            setattr(self, func.__name__, cached_method)
            return cached_method(*args, **kwargs)

        return wrapped_func

    return decorator


memoize_by_instance = method_lru_cache(int(1e7))
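

# Illustrative sketch, not part of the original module: memoize_by_instance
# caches per instance (the cached wrapper is stored on the instance and holds
# only a weak reference to it). _DemoCounter is a hypothetical class added
# purely to show the behaviour.
class _DemoCounter:

    def __init__(self):
        self.calls = 0

    @memoize_by_instance
    def double(self, x):
        self.calls += 1
        return 2 * x


# Calling _DemoCounter().double(3) twice on the same instance runs the body
# only once; the second call is served from the per-instance cache.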


class TrieNode:

    def __init__(self):
        self.children = {}
        self.is_end_of_word = False
        self.value = None


class Trie:

    def __init__(self):
        self.root = TrieNode()

    def insert(self, key, value):
        node = self.root
        for char in key:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.is_end_of_word = True
        node.value = value

    def get_best(self, word):
        node = self.root
        prefix = ""
        best_prefix = ""
        best_value = node.value
        for char in word:
            if char in node.children:
                node = node.children[char]
                prefix += char
                if node.is_end_of_word:
                    best_prefix = prefix
                    best_value = node.value
            else:
                break  # break if char not in trie
        remainder = word[len(best_prefix):]
        return best_prefix, best_value, remainder
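

# Illustrative sketch, not part of the original module: get_best returns the
# value stored under the longest inserted prefix of `word`, plus the leftover
# characters. The keys and values below are arbitrary example data.
def _demo_trie():
    trie = Trie()
    trie.insert("", "root value")
    trie.insert("ab", "ab value")
    assert trie.get_best("abc") == ("ab", "ab value", "c")
    assert trie.get_best("xyz") == ("", "root value", "xyz")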


@dataclass
class IncrementalParserState:
    """
    Parsing utility which enforces uniqueness of
    - interactive parser state stack
    - incomplete `partial_token` string

    The state of the parser and the incomplete token together comprise a
    unique parser state. The core function exposed is `self.step(new_seq)`,
    which returns a new IncrementalParserState with new_seq applied.

    Memoization strategy:
    - 1) Ensure uniqueness of (interactive_parser, partial_token)
    - 2) Cache class methods via `memoize_by_instance`, which considers
      id(self) and the function arguments
    """

    # unique state key
    interactive_parser: FastInteractiveParser

    # function of key
    terminal_candidates: list

    # shared across instances
    _ignored_terms: set
    _seq_validator: dict
    _memo: dict
    _full_seq_trie: Trie

    def __repr__(self):
        class_name = self.__class__.__name__
        state_stack = self.interactive_parser.parser_state.state_stack
        return f"{class_name}({state_stack})"

    @classmethod
    @functools.lru_cache(1000)
    def from_grammar(cls, grammar: str, start: str):
        lark_parser = Lark(
            grammar,
            regex=True,  # use `regex` not `re`
            start=start,
            parser="lalr",
            cache=True,  # results in 2-3x faster loading
        )
        base_interactive_parser = lark_parser.parse_interactive()
        interactive_parser = FastInteractiveParser(
            base_interactive_parser.parser,
            base_interactive_parser.parser_state,
            base_interactive_parser.lexer_thread)
        interactive_parser.lexer_thread.state.text = ""

        _seq_validator = {(term.name): get_pattern_validator(term.pattern)
                          for term in lark_parser.terminals}
        _seq_validator["$END"] = lambda seq: tuple(
            ["" if seq is None else None] * 2)

        parser = cls(interactive_parser=interactive_parser,
                     terminal_candidates=None,
                     _ignored_terms=set(lark_parser.lexer_conf.ignore),
                     _seq_validator=_seq_validator,
                     _memo={},
                     _full_seq_trie=Trie())
        parser._full_seq_trie.insert("", parser)
        return parser

    def new(self, **kwargs):
        """Cached constructor for a new state"""
        parser_state_key = hash(kwargs["interactive_parser"])
        if parser_state_key in self._memo:
            return self._memo[parser_state_key]

        instance_dict = {f.name: getattr(self, f.name) for f in fields(self)}
        instance_dict.update(kwargs)
        inst = self.__class__(**instance_dict)
        self._memo[parser_state_key] = inst
        return inst

    def __getitem__(self, full_seq):
        """Get the parser state, given a full sequence"""
        # pylint: disable=unused-variable
        match_seq, parser, remainder_seq = self._full_seq_trie.get_best(
            full_seq)
        if parser is None:
            return
        if remainder_seq:
            result = parser.step(remainder_seq)
            if result is None:
                return None
            remainder_seq, parser = result

        processed_seq = full_seq
        if remainder_seq:
            processed_seq = processed_seq[:-len(remainder_seq)]
        self._full_seq_trie.insert(processed_seq, parser)

        return remainder_seq, parser

    @memoize_by_instance
    def step(self, new_seq: str):
        """
        - Construct the extended (maybe-partial) token candidate
        - If it completely matches a terminal, create a parser state
          incremented by the new terminal; if there is leftover from new_seq,
          recurse on the new parser
        - If it partially matches, return a new parser with the updated
          partial token string and updated terminal candidates
        - If there is no partial match, return None
        """
        if new_seq == "":
            return "", self

        best_terminal, processed_seq, remainder_seq = (
            self.get_best_matched_terminal(new_seq))

        # invalid
        if best_terminal is None:
            return None

        # candidate doesn't complete terminal
        elif remainder_seq is None:
            return processed_seq, self

        # candidate completes terminal
        else:
            new_parser = self._next_with_new_terminal(best_terminal)
            if remainder_seq == "":
                return "", new_parser
            else:
                return new_parser.step(remainder_seq)

    @memoize_by_instance
    def _next_with_new_terminal(self, terminal):
        if terminal in self._ignored_terms:
            new_interactive_parser = self.interactive_parser
        else:
            new_interactive_parser = self.get_stepped_parser_state(terminal)

        return self.new(
            interactive_parser=new_interactive_parser,
            terminal_candidates=None,
        )

    def get_best_matched_terminal(self, seq):
        for terminal in self.accepts():
            processed_seq, remainder_seq = self._seq_validator[terminal](seq)
            if processed_seq:
                return terminal, processed_seq, remainder_seq

        return None, None, None

    @memoize_by_instance
    def get_stepped_parser_state(self, new_token_str):
        ip = copy(self.interactive_parser)
        ip.feed_token(Token(new_token_str, ""))
        return ip

    @memoize_by_instance
    def accepts(self):
        return set(self.interactive_parser.accepts()) | self._ignored_terms

    @memoize_by_instance
    def allowed_terminals(self):
        if self.terminal_candidates is not None:
            return tuple(sorted(self.terminal_candidates))
        return tuple(sorted(self.accepts()))

    @memoize_by_instance
    def is_valid_next_seq(self, new_seq: Optional[str]):
        if new_seq is None:
            return "$END" in self.allowed_terminals()
        return self.step(new_seq) is not None
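

# Illustrative sketch, not part of the original module: build a parser for a
# toy two-terminal grammar and step it one character at a time. The grammar
# string is a hypothetical example chosen only for illustration.
def _demo_incremental_parser():
    parser = IncrementalParserState.from_grammar('start: "a" "b"', "start")
    remainder, parser_after_a = parser.step("a")
    assert remainder == ""  # "a" consumed a complete terminal
    assert parser_after_a.is_valid_next_seq("b")
    assert not parser_after_a.is_valid_next_seq("c")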


class TokenVocab:
    """
    Normalized token vocabulary accounting for whitespace and multiple IDs
    per token
    - iter: iterate over normalized token strings
    - vocab[token_str]: return the set of token ids for that string
    """

    def __init__(self,
                 tokenizer: Union[PreTrainedTokenizer,
                                  PreTrainedTokenizerFast],
                 legal_chars: Optional[Set[str]] = None):
        self.norm_vocab = collections.defaultdict(set)
        for token_id in tokenizer.vocab.values():
            if token_id == tokenizer.eos_token_id:
                self.norm_vocab[None].add(token_id)
                continue

            # decode the token with a BOS prefix and then strip that prefix,
            # so leading whitespace and special prefix markers survive
            bos_len = len(tokenizer.bos_token)
            norm_token = tokenizer.decode([tokenizer.bos_token_id,
                                           token_id])[bos_len:]
            if legal_chars is None or all(
                    [char in legal_chars for char in norm_token]):
                self.norm_vocab[norm_token].add(token_id)

    def __iter__(self):
        return iter(self.norm_vocab)

    def __getitem__(self, tok_str):
        return self.norm_vocab[tok_str]


class NextTokenValidator:

    def __init__(
        self,
        tokenizer,
        grammar: str,
        grammar_start: str = "start",
        legal_chars: Optional[Set[str]] = None,
    ):
        self.tokenizer = tokenizer
        self.vocab = TokenVocab(tokenizer, legal_chars=legal_chars)
        self.root_parser = IncrementalParserState.from_grammar(
            grammar, grammar_start)

    def get_valid_next_token_strs(self, full_seq):
        """
        Generate valid token strings given the full sequence
        """
        result = self.root_parser[full_seq]
        if result is None:
            return
        partial_term, parser = result

        for token in self.vocab:
            if token is None:
                if partial_term == "" and parser.is_valid_next_seq(token):
                    yield None
            else:
                if parser.is_valid_next_seq(partial_term + token):
                    yield token

    def get_valid_next_token_ids(self, full_seq):
        """
        Generate valid token ids given the full sequence
        """
        for tok_str in self.get_valid_next_token_strs(full_seq):
            yield from self.vocab[tok_str]


class GrammarLogitsProcessor(NextTokenValidator):
    """
    Apply NextTokenValidator in __call__ and set the logits of excluded
    tokens to -inf
    """

    def __call__(self, logits: torch.Tensor,
                 token_ids: List[List[int]]) -> None:
        for i in range(len(token_ids)):
            # get valid token IDs given prior tokens
            sequence = self.tokenizer.decode(token_ids[i])
            valid_token_ids = self.get_valid_next_token_ids(sequence)
            valid = torch.tensor(list(valid_token_ids), dtype=torch.long)

            # modify logits given valid token IDs
            N = len(logits[i])
            mask = torch.zeros(N, dtype=torch.bool)
            mask[valid] = True
            logits[i][~mask] = float("-inf")
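

if __name__ == "__main__":
    # Illustrative sketch, not part of the original module: wire the processor
    # up end to end with a toy grammar. The grammar and the "gpt2" tokenizer
    # name are arbitrary assumptions; any HF tokenizer with BOS/EOS tokens
    # should work, and running this block downloads the tokenizer files.
    from transformers import AutoTokenizer

    toy_grammar = 'start: "yes" | "no"'
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    processor = GrammarLogitsProcessor(tokenizer, toy_grammar)

    # pretend the model has generated "yes" so far; batch size of one
    generated_ids = tokenizer.encode("yes")
    logits = torch.zeros(1, len(tokenizer.vocab))
    processor(logits, [generated_ids])

    # everything that cannot extend a valid derivation is now -inf
    n_valid = int((logits[0] > float("-inf")).sum())
    print(f"{n_valid} token(s) remain valid after 'yes'")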