# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Type, Union

import torch
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel


class RegexLogitsProcessor:

    def __init__(self, regex_string: str, tokenizer):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        fsm = RegexFSM(regex_string, tokenizer)
        self.fsm = fsm

    def init_state(self):
        """Initialize the FSM states."""
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        seq_id = hash(tuple(input_ids))

        if len(input_ids) == 0:
            self.init_state()
        else:
            last_token = input_ids[-1]
            last_seq_id = hash(tuple(input_ids[:-1]))
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[last_seq_id], last_token)

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

        mask = torch.full((scores.shape[-1], ),
                          -math.inf,
                          device=scores.device)
        mask[allowed_tokens] = 0
        scores.add_(mask)

        return scores

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer so it can be used to compile the FSM.

        The API of Outlines tokenizers is slightly different from that of
        `transformers`. In addition, we need to handle the missing spaces in
        Llama's tokenizer to be able to compile FSMs for this model.

        """
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        tokenizer.convert_token_to_string = convert_token_to_string

        return tokenizer
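

# A minimal usage sketch (illustrative, not part of the original module):
# the serving engine is expected to call `init_state()` once per request
# and then apply the processor at every decoding step, e.g.
#
#     processor = RegexLogitsProcessor(r"[0-9]{4}", tokenizer)
#     processor.init_state()
#     next_token_logits = processor(generated_ids, next_token_logits)
#
# Here `tokenizer` and `generated_ids` are assumed to come from the
# surrounding engine; `init_state()` must run before the first call.
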

class JSONLogitsProcessor(RegexLogitsProcessor):

    def __init__(
        self,
        schema: Union[str, Dict, Type[BaseModel]],
        tokenizer,
        whitespace_pattern: Optional[str] = None,
    ):
        r"""Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema that encodes the structure we want the model to
            generate
        tokenizer
            The model's tokenizer
        whitespace_pattern
            Pattern to use for JSON syntactic whitespace
            (doesn't impact string literals)
            Example: allow only a single space or newline with
            `whitespace_pattern=r"[\n ]?"`

        """
        if isinstance(schema, type(BaseModel)):
            schema_str = json.dumps(schema.model_json_schema())
        elif isinstance(schema, Dict):
            schema_str = json.dumps(schema)
        elif isinstance(schema, str):
            schema_str = schema
        else:
            raise ValueError(
                f"Cannot parse schema {schema}. The schema must be either "
                "a Pydantic class, a dictionary or a string that contains "
                "the JSON Schema specification")
        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
        super().__init__(regex_string, tokenizer)
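

if __name__ == "__main__":
    # Minimal smoke test; a sketch that is not part of the original module.
    # It assumes `transformers` is installed and uses "gpt2" as an arbitrary
    # example tokenizer; any HF tokenizer with an EOS token should work.
    from transformers import AutoTokenizer

    class Person(BaseModel):
        name: str
        age: int

    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    processor = JSONLogitsProcessor(Person, hf_tokenizer)
    processor.init_state()

    # With no tokens generated yet, only tokens that can begin a valid JSON
    # object for `Person` keep a finite logit after masking.
    logits = torch.zeros(len(hf_tokenizer.get_vocab()))
    masked = processor([], logits)
    print(int((masked > -math.inf).sum()), "tokens allowed at step 0")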
|