
feat: grammar support (#206)

* add grammars

* guess that doesn't work

* typo

* formatting
AlpinDale 1 year ago
parent
commit
0adab894fe

+ 1 - 0
.pylintrc

@@ -81,6 +81,7 @@ disable=abstract-method,
         import-self,
         import-star-module-level,
         import-outside-toplevel,
+        use-a-generator,
         inconsistent-return-statements,
         input-builtin,
         intern-builtin,

+ 470 - 0
aphrodite/common/grammar.py

@@ -0,0 +1,470 @@
+import collections
+from copy import deepcopy, copy
+from dataclasses import dataclass, fields
+import functools
+import regex
+import torch
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import Optional, List, Set, Union
+import weakref
+
+import ray
+
+from lark import Lark
+from lark.parsers.lalr_interactive_parser import InteractiveParser
+from lark.parsers.lalr_parser_state import ParserState
+from lark.lexer import Token, Pattern, PatternStr, PatternRE
+
+
+class FastParserState(ParserState):
+    copy_memo = {}
+
+    def __copy__(self):
+        new_value_stack = []
+        for value in self.value_stack:
+            key = f"{id(self)}_{id(value)}"
+            if key not in self.copy_memo:
+                self.copy_memo[key] = deepcopy(value, self.copy_memo)
+            new_value_stack.append(self.copy_memo[key])
+
+        new_instance = type(self)(
+            self.parse_conf,
+            self.lexer,
+            copy(self.state_stack),
+            new_value_stack,
+        )
+
+        self.copy_memo[id(self)] = new_instance
+        return new_instance
+
+
+class FastInteractiveParser(InteractiveParser):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.parser_state = FastParserState(
+            self.parser_state.parse_conf,
+            self.parser_state.lexer,
+            self.parser_state.state_stack,
+            self.parser_state.value_stack,
+        )
+        self.hash_val = None
+
+    def __hash__(self):
+        if self.hash_val is None:
+            self.hash_val = hash(tuple(self.parser_state.state_stack))
+        return self.hash_val
+
+    def __copy__(self):
+        return type(self)(
+            self.parser,
+            copy(self.parser_state),
+            copy(self.lexer_thread),
+        )
+
+
+def get_pattern_validator(pattern: Pattern):
+    """
+    Accepts a pattern object, either lark.lexer.PatternStr or
+    lark.lexer.PatternRE
+    Returns a function which validates a complete or partial string
+    Returns Tuple with 2 values
+    - 0) The processed portion of the sequence (None if no match at all)
+    - 1) None if doesn't complete terminal, "" if completes terminal with no
+    remainder, or "remainder"
+    """
+    if isinstance(pattern, PatternRE):
+        compiled_pattern = regex.compile(pattern.value)
+
+        @functools.lru_cache(int(1e6))
+        def get_re_matched_parts(seq):
+            # match complete terminal, potentially with leftover seq
+            complete_terminal_match = compiled_pattern.match(seq)
+            if complete_terminal_match:
+                spans = complete_terminal_match.spans()
+                if spans:
+                    span = complete_terminal_match.spans()[0]
+                    if span[0] == 0:
+                        processed_seq = seq[:span[1]]
+                        remainder_seq = seq[span[1]:]
+                        return processed_seq, remainder_seq
+
+            # match doesn't complete terminal, but the sequence is fully
+            # allowed
+            partial_terminal_match = compiled_pattern.fullmatch(seq,
+                                                                partial=True)
+            if partial_terminal_match:
+                return seq, None
+
+            return None, None
+
+        return get_re_matched_parts
+
+    elif isinstance(pattern, PatternStr):
+        base_str = pattern.value
+
+        @functools.lru_cache(int(1e6))
+        def get_str_matched_parts(seq):
+            if seq.startswith(base_str):
+                processed_seq = seq[:len(base_str)]
+                remainder_seq = seq[len(base_str):]
+                return processed_seq, remainder_seq
+            elif base_str.startswith(seq):
+                return seq, None
+            else:
+                return None, None
+
+        return get_str_matched_parts
+
+    else:
+        raise TypeError(f"Invalid pattern type: {type(pattern)}")
+
+
+def method_lru_cache(*lru_args, **lru_kwargs):
+    # https://stackoverflow.com/a/44078118
+    def decorator(func):
+
+        @functools.wraps(func)
+        def wrapped_func(self, *args, **kwargs):
+            self_weak = weakref.ref(self)
+
+            @functools.wraps(func)
+            @functools.lru_cache(*lru_args, **lru_kwargs)
+            def cached_method(*args, **kwargs):
+                return func(self_weak(), *args, **kwargs)
+
+            setattr(self, func.__name__, cached_method)
+            return cached_method(*args, **kwargs)
+
+        return wrapped_func
+
+    return decorator
+
+
+memoize_by_instance = method_lru_cache(int(1e7))
+
+
+class TrieNode:
+
+    def __init__(self):
+        self.children = {}
+        self.is_end_of_word = False
+        self.value = None
+
+
+class Trie:
+
+    def __init__(self):
+        self.root = TrieNode()
+
+    def insert(self, key, value):
+        node = self.root
+        for char in key:
+            if char not in node.children:
+                node.children[char] = TrieNode()
+            node = node.children[char]
+        node.is_end_of_word = True
+        node.value = value
+
+    def get_best(self, word):
+        node = self.root
+        prefix = ""
+        best_prefix = ""
+        best_value = node.value
+        for char in word:
+            if char in node.children:
+                node = node.children[char]
+                prefix += char
+                if node.is_end_of_word:
+                    best_prefix = prefix
+                    best_value = node.value
+            else:
+                break  # break if char not in trie
+
+        remainder = word[len(best_prefix):]
+        return best_prefix, best_value, remainder
+
+
+@dataclass
+class IncrementalParserState:
+    """
+    Parsing utility which enforces uniqueness of
+    - interactive parser state stack
+    - incomplete `partial_token` string
+    the state of the parser and the incomplete token comprise a unique parser
+    state. Core function exposed is `self.step(new_seq)`
+    - Returns a new IncrementalParserState based with new_seq applied
+    Memoization strategy is
+    - 1) Ensure uniqueness of (interactive_parser, partial_token)
+    - 2) Cache class methods via `memoize_by_instance` which considers id(self)
+        and fn arguments
+    """
+
+    # unique state key
+    interactive_parser: FastInteractiveParser
+
+    # function of key
+    terminal_candidates: list
+
+    # shared across instances
+    _ignored_terms: set
+    _seq_validator: dict
+    _memo: dict
+    _full_seq_trie: Trie
+
+    def __repr__(self):
+        class_name = self.__class__.__name__
+        state_stack = self.interactive_parser.parser_state.state_stack
+        return f"{class_name}({state_stack})"
+
+    @classmethod
+    @functools.lru_cache(1000)
+    def from_grammar(cls, grammar: str, start: str):
+        lark_parser = Lark(
+            grammar,
+            regex=True,  # use `regex` not `re`
+            start=start,
+            parser="lalr",
+            cache=True,  # results in 2-3x faster loading
+        )
+        base_interactive_parser = lark_parser.parse_interactive()
+        interactive_parser = FastInteractiveParser(
+            base_interactive_parser.parser,
+            base_interactive_parser.parser_state,
+            base_interactive_parser.lexer_thread)
+        interactive_parser.lexer_thread.state.text = ""
+
+        _seq_validator = {(term.name): get_pattern_validator(term.pattern)
+                          for term in lark_parser.terminals}
+        _seq_validator["$END"] = lambda seq: tuple(
+            ["" if seq is None else None] * 2)
+
+        parser = cls(interactive_parser=interactive_parser,
+                     terminal_candidates=None,
+                     _ignored_terms=set(lark_parser.lexer_conf.ignore),
+                     _seq_validator=_seq_validator,
+                     _memo={},
+                     _full_seq_trie=Trie())
+        parser._full_seq_trie.insert("", parser)
+        return parser
+
+    def new(self, **kwargs):
+        """Cached create now state"""
+        parser_state_key = hash(kwargs["interactive_parser"])
+        if parser_state_key in self._memo:
+            return self._memo[parser_state_key]
+
+        instance_dict = {f.name: getattr(self, f.name) for f in fields(self)}
+        instance_dict.update(kwargs)
+        inst = self.__class__(**instance_dict)
+
+        self._memo[parser_state_key] = inst
+
+        return inst
+
+    def __getitem__(self, full_seq):
+        """Get the parser state, given a full sequence"""
+        # pylint: disable=unused-variable
+        match_seq, parser, remainder_seq = self._full_seq_trie.get_best(
+            full_seq)
+        if parser is None:
+            return
+        if remainder_seq:
+            result = parser.step(remainder_seq)
+            if result is None:
+                return None
+            remainder_seq, parser = result
+            processed_seq = full_seq
+            if remainder_seq:
+                processed_seq = processed_seq[:-len(remainder_seq)]
+            self._full_seq_trie.insert(processed_seq, parser)
+        return remainder_seq, parser
+
+    @memoize_by_instance
+    def step(self, new_seq: str):
+        """
+        - Construct extended (maybe-partial) token candidate
+        - If complete match, create new-terminal incremented parser state
+          - there is leftover from new_seq, recurse on the new parser
+        - If partial matches,
+              return new parser with updated partial token str and updated
+              terminal candidates
+        - If no partial matches, return None
+        """
+        if new_seq == "":
+            return "", self
+
+        best_terminal, processed_seq, remainder_seq = (
+            self.get_best_matched_terminal(new_seq))
+
+        # invalid
+        if best_terminal is None:
+            return None
+
+        # candidate doesn't complete terminal
+        elif remainder_seq is None:
+            return processed_seq, self
+
+        # candidate completes terminal
+        else:
+            new_parser = self._next_with_new_terminal(best_terminal)
+            if remainder_seq == "":
+                return "", new_parser
+            else:
+                return new_parser.step(remainder_seq)
+
+    @memoize_by_instance
+    def _next_with_new_terminal(self, terminal):
+        if terminal in self._ignored_terms:
+            new_interactive_parser = self.interactive_parser
+        else:
+            new_interactive_parser = self.get_stepped_parser_state(terminal)
+
+        return self.new(
+            interactive_parser=new_interactive_parser,
+            terminal_candidates=None,
+        )
+
+    def get_best_matched_terminal(self, seq):
+        for terminal in self.accepts():
+            processed_seq, remainder_seq = self._seq_validator[terminal](seq)
+            if processed_seq:
+                return terminal, processed_seq, remainder_seq
+
+        return None, None, None
+
+    @memoize_by_instance
+    def get_stepped_parser_state(self, new_token_str):
+        ip = copy(self.interactive_parser)
+        ip.feed_token(Token(new_token_str, ""))
+        return ip
+
+    @memoize_by_instance
+    def accepts(self):
+        return set(self.interactive_parser.accepts()) | self._ignored_terms
+
+    @memoize_by_instance
+    def allowed_terminals(self):
+        if self.terminal_candidates is not None:
+            return tuple(sorted(self.terminal_candidates))
+        return tuple(sorted(self.accepts()))
+
+    @memoize_by_instance
+    def is_valid_next_seq(self, new_seq: Optional[str]):
+        if new_seq is None:
+            return "$END" in self.allowed_terminals()
+        return self.step(new_seq) is not None
+
+
+class TokenVocab:
+    """
+    Normalized token vocabulary accounting for whitespace and multiple IDs
+    per token
+    - iter: iterate over normalized token strings
+    - vocab[token_str]: return token id set
+    """
+
+    def __init__(self,
+                 tokenizer: Union[PreTrainedTokenizer,
+                                  PreTrainedTokenizerFast],
+                 legal_chars: Optional[Set[str]] = None):
+
+        self.norm_vocab = collections.defaultdict(set)
+        for token_id in tokenizer.vocab.values():
+            if token_id == tokenizer.eos_token_id:
+                self.norm_vocab[None].add(token_id)
+                continue
+            bos_len = len(tokenizer.bos_token)
+            norm_token = tokenizer.decode([tokenizer.bos_token_id,
+                                           token_id])[bos_len:]
+            if legal_chars is None or all(
+                [char in legal_chars for char in norm_token]):
+                self.norm_vocab[norm_token].add(token_id)
+
+    def __iter__(self):
+        return iter(self.norm_vocab)
+
+    def __getitem__(self, tok_str):
+        return self.norm_vocab[tok_str]
+
+
+class NextTokenValidator:
+
+    def __init__(
+        self,
+        tokenizer,
+        grammar: str,
+        grammar_start: str = "start",
+        legal_chars: Optional[Set[str]] = None,
+    ):
+        self.tokenizer = tokenizer
+        self.vocab = TokenVocab(tokenizer, legal_chars=legal_chars)
+        self.root_parser = IncrementalParserState.from_grammar(
+            grammar, grammar_start)
+
+    def get_valid_next_token_strs(self, full_seq):
+        """
+        Generate valid token strings given the full sequence
+        """
+
+        result = self.root_parser[full_seq]
+        if result is None:
+            return []
+        partial_term, parser = result
+        for token in self.vocab:
+            if token is None:
+                if partial_term == "" and parser.is_valid_next_seq(token):
+                    yield None
+            else:
+                if parser.is_valid_next_seq(partial_term + token):
+                    yield token
+
+    def get_valid_next_token_ids(self, full_seq):
+        """
+        Generate valid token ids given the full sequence
+        """
+        for tok_str in self.get_valid_next_token_strs(full_seq):
+            yield from self.vocab[tok_str]
+
+
+class GrammarLogitsProcessor(NextTokenValidator):
+    """
+    Apply NextTokenValidator in __call__; set excluded tokens' logits to -inf
+    """
+
+    def __call__(self, token_ids: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        # get valid token IDs given prior tokens
+        sequence = self.tokenizer.decode(token_ids)
+        valid_token_ids = self.get_valid_next_token_ids(sequence)
+        valid = torch.tensor(list(valid_token_ids), dtype=torch.long)
+
+        # modify logits given valid token IDs
+        N = len(logits)
+        mask = torch.zeros(N, dtype=torch.bool)
+        mask[valid] = True
+        logits[~mask] = float("-inf")
+        return logits
+
+
+@ray.remote
+class GrammarLogitsProcessorActor:
+
+    def __init__(self, *args, **kwargs):
+        self.processor = GrammarLogitsProcessor(*args, **kwargs)
+
+    def process_logits(self, token_ids: List[int],
+                       logits: torch.Tensor) -> torch.Tensor:
+        return self.processor(token_ids, logits)
+
+
+class RayRemoteGrammarLogitsProcessor:
+
+    def __init__(self, *args, **kwargs):
+        self.actor = GrammarLogitsProcessorActor.remote(*args, **kwargs)
+
+    def __call__(self, token_ids: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        logits_cpu = logits.cpu()
+        result_id = self.actor.process_logits.remote(token_ids, logits_cpu)
+        return ray.get(result_id)
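
For context, a minimal usage sketch of the new module (not part of this
commit): it wires GrammarLogitsProcessor to a Hugging Face tokenizer with a
toy arithmetic Lark grammar and masks a dummy logits vector. The tokenizer
name, the grammar, and the legal_chars restriction are illustrative
assumptions, not anything shipped in this diff.

import torch
from transformers import AutoTokenizer

from aphrodite.common.grammar import GrammarLogitsProcessor

# Toy arithmetic grammar in Lark syntax (LALR-compatible); purely illustrative.
TOY_GRAMMAR = r"""
start: expr
expr: NUMBER | expr OP NUMBER
NUMBER: /[0-9]+/
OP: "+" | "-"
"""

# Any HF tokenizer with bos/eos tokens should do; GPT-2 is assumed here.
# legal_chars shrinks the vocabulary walk, which is otherwise slow.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
processor = GrammarLogitsProcessor(
    tokenizer=tokenizer,
    grammar=TOY_GRAMMAR,
    legal_chars=set("0123456789+-"),
)

# Pretend the model already emitted "1+", then mask the next-step logits so
# that only grammar-legal continuations stay finite.
token_ids = tokenizer.encode("1+")
dummy_logits = torch.zeros(len(tokenizer))
masked = processor(token_ids, dummy_logits)
print(int(torch.isfinite(masked).sum()), "token ids remain allowed")

Each call walks the (filtered) token vocabulary; the commit also ships a
Ray-backed RayRemoteGrammarLogitsProcessor, which the API server selects when
the engine runs with Ray workers (see the api_server.py hunk below).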

+ 13 - 0
aphrodite/endpoints/openai/api_server.py

@@ -35,6 +35,8 @@ from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
 from aphrodite.common.utils import random_uuid
 from aphrodite.common.logits_processor import BiasLogitsProcessor
+from aphrodite.common.grammar import (GrammarLogitsProcessor,
+                                      RayRemoteGrammarLogitsProcessor)
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
@@ -548,6 +550,17 @@ async def create_completion(
                 request.logit_bias.items()))
         logit_processors = [BiasLogitsProcessor(biases)]
 
+    if request.grammar:
+        if engine.worker_use_ray:
+            grammar_logits_processor = RayRemoteGrammarLogitsProcessor(
+                tokenizer=tokenizer, grammar=request.grammar)
+        else:
+            grammar_logits_processor = GrammarLogitsProcessor(
+                tokenizer=tokenizer, grammar=request.grammar)
+        logit_processors = [grammar_logits_processor]
+    else:
+        logit_processors = []
+
     # OpenAI API supports echoing the prompt when max_tokens is 0.
     echo_without_generation = request.echo and request.max_tokens == 0
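
A hypothetical client call exercising the new field against the
OpenAI-compatible completions route; the host, port, model name, and grammar
below are placeholders, not values taken from this commit.

# Placeholder host/port and model; only the new `grammar` field matters here.
import requests

resp = requests.post(
    "http://localhost:2242/v1/completions",
    json={
        "model": "my-model",
        "prompt": "Pick a number between 1 and 100: ",
        "max_tokens": 8,
        "temperature": 0.0,
        "grammar": "start: NUMBER\nNUMBER: /[0-9]+/",  # constrain output to digits
    },
)
print(resp.json()["choices"][0]["text"])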
 

+ 1 - 0
aphrodite/endpoints/openai/protocol.py

@@ -132,6 +132,7 @@ class CompletionRequest(BaseModel):
     custom_token_bans: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
+    grammar: Optional[str] = None
 
 
 class LogProbs(BaseModel):

+ 2 - 1
requirements.txt

@@ -16,4 +16,5 @@ colorlog
 einops # for phi
 aioprometheus[starlette] # for prometheus metrics
 triton >= 2.1.0
-pynvml == 11.5.0
+lark == 1.1.8 # for grammars
+pynvml == 11.5.0