import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Optional, Tuple, Union

from pydantic import BaseModel

from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 CompletionRequest)
from aphrodite.modeling.outlines_logits_processors import (
    JSONLogitsProcessor, RegexLogitsProcessor)


class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"


global_thread_pool = None  # used for building logits-processor FSMs


async def get_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer
) -> Optional[Union[JSONLogitsProcessor, RegexLogitsProcessor]]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logits processors by (guide, tokenizer); on a cache hit we make
    a shallow copy so the underlying compiled FSM is reused while each
    request gets its own mutable decoding state.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide:
        return None

    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()

    # Compiling an FSM for a complex guide can be slow, so build it in the
    # thread pool rather than blocking the event loop.
    result = await loop.run_in_executor(global_thread_pool,
                                        _get_cached_logits_processor, guide,
                                        tokenizer, mode)

    logits_processor = copy(result)
    # reset the logits processor's internal state
    logits_processor.init_state()
    return logits_processor


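# Usage sketch (illustrative, not part of this module's API): in an async
# request handler, the returned processor would typically be attached to the
# request's sampling parameters. `sampling_params` below is a hypothetical
# placeholder for wherever the serving layer collects logits processors.
#
#     processor = await get_guided_decoding_logits_processor(
#         request, tokenizer)
#     if processor is not None:
#         sampling_params.logits_processors = [processor]

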
def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[Optional[str], Optional[GuidedDecodingMode]]:
    if request.guided_json:
        if not isinstance(request.guided_json, (str, dict, BaseModel)):
            raise TypeError("JSON schema must be str, dict, or BaseModel")

        json = request.guided_json
        if isinstance(json, dict):
            # turn dict into hashable string
            json = json_dumps(json, sort_keys=True)
        elif isinstance(json, BaseModel):
            # use pydantic signature so that different model classes
            # with the same fields will get hashed the same
            json = str(json.__signature__)
        return json, GuidedDecodingMode.JSON

    elif request.guided_regex:
        if not isinstance(request.guided_regex, str):
            raise TypeError("Regex must be a string")
        return request.guided_regex, GuidedDecodingMode.REGEX

    elif request.guided_choice:
        if not isinstance(request.guided_choice, list):
            raise TypeError("Choices must be a list")

        # choice just uses regex: escape each option and join them with
        # alternation
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE

    else:
        return None, None


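# For example, guided_choice=["yes", "no", "maybe"] produces the guide
# "(yes|no|maybe)", and regex metacharacters in choices are escaped, so
# guided_choice=["a.b"] produces "(a\.b)".

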
@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, tokenizer,
                                 mode: GuidedDecodingMode):
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
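

# Note on cache semantics (a sketch, assuming the outlines-style processors
# expose their compiled FSM as an `fsm` attribute; treat that name as
# illustrative): two calls with the same guide return distinct shallow
# copies sharing one compiled FSM, and `init_state()` gives each copy its
# own position state.
#
#     p1 = await get_guided_decoding_logits_processor(request, tokenizer)
#     p2 = await get_guided_decoding_logits_processor(request, tokenizer)
#     assert p1 is not p2        # separate copies, safe to mutate state
#     assert p1.fsm is p2.fsm    # compiled FSM built only once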