123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- import asyncio
- import concurrent.futures
- from copy import copy
- from enum import Enum
- from functools import lru_cache
- from json import dumps as json_dumps
- from re import escape as regex_escape
- from typing import Tuple, Union
- from pydantic import BaseModel
- from transformers import PreTrainedTokenizerBase
- from aphrodite.endpoints.openai.protocol import (
- ChatCompletionRequest,
- CompletionRequest,
- )
- from aphrodite.modeling.outlines_logits_processors import (
- CFGLogitsProcessor,
- JSONLogitsProcessor,
- RegexLogitsProcessor,
- )
class GuidedDecodingMode(Enum):
    """The kinds of guided (constrained) decoding a request can select.

    Each member's value is the lowercase form of its name.
    """

    JSON = "json"        # guide comes from request.guided_json
    REGEX = "regex"      # guide is request.guided_regex
    CHOICE = "choice"    # guide is an alternation regex built from choices
    GRAMMAR = "grammar"  # guide is a grammar (request or built-in JSON one)
# Lark grammar applied when a request sets response_format to "json_object".
# Adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# with one deliberate change: the entry rule is `start: object | array`
# rather than `start: value`, so a bare scalar (string, number, true, false,
# null) is rejected as the top-level JSON value. Permitting scalar roots
# appeared to make llama models keep generating without ever stopping.
JSON_GRAMMAR = r"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""

# Executor that builds logits processors (and their FSMs) off the event
# loop. It stays None until the first guided request creates it.
global_thread_pool = None
async def get_guided_decoding_logits_processor(
    request: Union[CompletionRequest, ChatCompletionRequest],
    tokenizer: PreTrainedTokenizerBase,
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """Return a logits processor enforcing the request's guided decoding.

    The request's guided decoding parameters are resolved into a
    ``(guide, mode)`` pair. If the request asks for no guidance, ``None`` is
    returned.

    Processors are built by ``_get_cached_logits_processor``, which is
    memoised on ``(guide, tokenizer, mode)``. That call runs on a small
    worker thread pool so the event loop is not blocked while a processor is
    constructed.

    Every call shallow-copies the cached processor and resets its internal
    state. Heavy shared members (such as the underlying FSM) are reused
    across requests, while per-request state starts fresh.

    Args:
        request: The completion or chat completion request.
        tokenizer: Tokenizer forwarded to the logits processor constructor.

    Returns:
        A JSON, regex, or CFG logits processor, or ``None`` when no guided
        decoding was requested.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide:
        return None

    # Lazily create the shared pool. There is no await between the check
    # and the assignment, so coroutines on one event loop cannot race here.
    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)

    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(global_thread_pool,
                                        _get_cached_logits_processor, guide,
                                        tokenizer, mode)

    # The cached object is shared between requests, so hand out a shallow
    # copy and reset that copy's internal state.
    logits_processor = copy(result)
    logits_processor.init_state()
    return logits_processor
def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[Union[str, None], Union[GuidedDecodingMode, None]]:
    """Extract the guided decoding guide and its mode from ``request``.

    Parameters are checked in priority order, and the first one set wins:
    ``guided_json``, ``guided_regex``, ``guided_choice``, ``guided_grammar``,
    then ``response_format.type == "json_object"``. The last of these maps to
    the built-in ``JSON_GRAMMAR``.

    Guides are normalised to strings so they can be used as ``lru_cache``
    keys downstream:

    - a dict JSON schema is serialized with ``json.dumps``;
    - a pydantic model instance is replaced by its serialized JSON schema;
    - choices are regex-escaped and joined into one alternation group.

    Returns:
        ``(guide, mode)``, or ``(None, None)`` if no guidance was requested.
    """
    if request.guided_json:
        guide = request.guided_json
        if isinstance(guide, dict):
            # Turn the dict into a hashable string.
            guide = json_dumps(guide)
        elif isinstance(guide, BaseModel):
            # The previous code used ``str(guide.__signature__)``. On
            # pydantic models ``__signature__`` is class-only, so reading it
            # from an instance raises AttributeError. A signature string is
            # also not a JSON schema. Serialize the model's JSON schema
            # instead, mirroring the dict branch above.
            model_cls = type(guide)
            if hasattr(model_cls, "model_json_schema"):  # pydantic v2
                schema = model_cls.model_json_schema()
            else:  # pydantic v1
                schema = model_cls.schema()
            guide = json_dumps(schema)
        return guide, GuidedDecodingMode.JSON

    if request.guided_regex:
        return request.guided_regex, GuidedDecodingMode.REGEX

    if request.guided_choice:
        # Choices are enforced through a regex: escape each choice literally
        # and accept exactly one of them.
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE

    if request.guided_grammar:
        return request.guided_grammar, GuidedDecodingMode.GRAMMAR

    if (request.response_format is not None
            and request.response_format.type == "json_object"):
        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR

    return None, None
@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
                                 tokenizer: PreTrainedTokenizerBase,
                                 mode: GuidedDecodingMode):
    """Construct the logits processor that enforces ``guide`` under ``mode``.

    Results are memoised by ``lru_cache`` (at most 32 entries), keyed on all
    three arguments, so each argument must be hashable. Repeated calls with
    equal arguments return the very same object. Callers that need
    per-request state should therefore copy the result.

    Raises:
        ValueError: If ``mode`` is not one of the known guided decoding modes.
    """
    if mode in (GuidedDecodingMode.REGEX, GuidedDecodingMode.CHOICE):
        # A CHOICE guide is already an alternation regex, so both share the
        # regex processor.
        processor_cls = RegexLogitsProcessor
    elif mode == GuidedDecodingMode.JSON:
        processor_cls = JSONLogitsProcessor
    elif mode == GuidedDecodingMode.GRAMMAR:
        processor_cls = CFGLogitsProcessor
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
    return processor_cls(guide, tokenizer)
|