import asyncio
import concurrent.futures
from enum import Enum
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from aphrodite.endpoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    CompletionRequest)
from aphrodite.modeling.guided_decoding.guided_fields import (
    GuidedDecodingRequest)
from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
    CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)


class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"
    GRAMMAR = "grammar"


# Adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that `start: value` was changed to
# `start: object | array`, so scalar values are rejected as the root of the
# JSON document. Starting with a scalar root seems to cause llama to generate
# without stopping.
JSON_GRAMMAR = r"""
?start: object | array

?value: object
      | array
      | UNESCAPED_STRING
      | SIGNED_NUMBER      -> number
      | "true"             -> true
      | "false"            -> false
      | "null"             -> null

array  : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair   : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""


global_thread_pool = None  # used for generating logits processor FSMs


async def get_outlines_guided_decoding_logits_processor(
    request: Union[CompletionRequest, ChatCompletionRequest],
    tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide or not mode:
        return None

    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()

    # Compiling the guide into an FSM can be slow, so run it in a worker
    # thread instead of blocking the event loop.
    return await loop.run_in_executor(global_thread_pool,
                                      _get_logits_processor, guide, tokenizer,
                                      mode, request.guided_whitespace_pattern)
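
# Illustrative call site (a sketch, not aphrodite's actual serving code);
# `sampling_params` is assumed to expose a `logits_processors` list:
#
#   processor = await get_outlines_guided_decoding_logits_processor(
#       request, tokenizer)
#   if processor is not None:
#       sampling_params.logits_processors.append(processor)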


def get_local_outlines_guided_decoding_logits_processor(
    guided_options: GuidedDecodingRequest,
    tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
    Synchronous variant of the above: given a GuidedDecodingRequest, check
    for guided decoding parameters and get the necessary logits processor
    for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    guide, mode = _get_guide_and_mode(guided_options)
    if not guide or not mode:
        return None

    return _get_logits_processor(guide, tokenizer, mode,
                                 guided_options.guided_whitespace_pattern)
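
# Minimal offline usage sketch (the model name and regex are illustrative,
# and `outlines` must be installed for the processor classes to build):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   options = GuidedDecodingRequest(guided_regex=r"\d{3}-\d{4}")
#   processor = get_local_outlines_guided_decoding_logits_processor(
#       options, tokenizer)  # -> RegexLogitsProcessor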


def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest,
                   GuidedDecodingRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
    # If the request is a chat completion request AND the tool choice is a
    # named tool choice, do guided decoding using that tool's parameters as
    # the JSON schema.
    if isinstance(request, ChatCompletionRequest) and isinstance(
            request.tool_choice, ChatCompletionNamedToolChoiceParam):
        # Guided generation for tool/function parameters
        if request.tool_choice.type == "function":
            for tool in request.tools:
                if (tool.type == "function" and tool.function.name
                        == request.tool_choice.function.name):
                    json = json_dumps(tool.function.parameters,
                                      sort_keys=True)
                    return json, GuidedDecodingMode.JSON
        return None, None
    elif request.guided_json:
        if isinstance(request.guided_json, dict):
            # turn the dict into a hashable string
            json = json_dumps(request.guided_json)
        elif isinstance(request.guided_json, BaseModel):
            # use the pydantic signature so that different model classes
            # with the same fields hash the same
            json = str(request.guided_json.__signature__)
        else:
            json = request.guided_json
        return json, GuidedDecodingMode.JSON
    elif request.guided_regex:
        return request.guided_regex, GuidedDecodingMode.REGEX
    elif request.guided_choice:
        # choice is implemented as a regex alternation over the escaped
        # choices
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE
    elif request.guided_grammar:
        return request.guided_grammar, GuidedDecodingMode.GRAMMAR
    elif (not isinstance(request, GuidedDecodingRequest)
          and request.response_format is not None
          and request.response_format.type == "json_object"):
        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
    elif (not isinstance(request, GuidedDecodingRequest)
          and request.response_format is not None
          and request.response_format.type == "json_schema"
          and request.response_format.json_schema is not None
          and request.response_format.json_schema.json_schema is not None):
        json = json_dumps(request.response_format.json_schema.json_schema)
        return json, GuidedDecodingMode.JSON
    else:
        return None, None
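
# A sketch of the resulting (guide, mode) mapping for a few inputs:
#   guided_choice=["yes", "no"]           -> ("(yes|no)", GuidedDecodingMode.CHOICE)
#   guided_regex=r"\d+"                   -> (r"\d+", GuidedDecodingMode.REGEX)
#   response_format.type == "json_object" -> (JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR)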


def _get_logits_processor(
    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
    whitespace_pattern: Union[str, None]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.GRAMMAR:
        return CFGLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
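
# All three processor classes share the same call contract; a sketch of how a
# sampler applies one per decoding step (`input_ids` is the list of generated
# token ids, `scores` the next-token logits tensor):
#
#   scores = processor(input_ids, scores)  # disallowed tokens pushed to -inf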