from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union

from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
                              RegexParser, StringParser,
                              TokenEnforcerTokenizerData, UnionParser)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from aphrodite.common.sampling_params import LogitsProcessorFunc
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 CompletionRequest)
from aphrodite.modeling.guided_decoding.lm_format_enforcer_logits_processors import (  # noqa: E501
    build_aphrodite_logits_processor,
    build_aphrodite_token_enforcer_tokenizer_data)
from aphrodite.modeling.guided_decoding.outlines_decoding import \
    get_outlines_guided_decoding_logits_processor


async def get_lm_format_enforcer_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessorFunc]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and build the logits processor for the given guide.

    The token enforcer tokenizer data is cached per tokenizer, so repeated
    requests with the same tokenizer reuse the preprocessing work instead of
    rebuilding it per request.
    """
    tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
        tokenizer)
    character_level_parser: CharacterLevelParser
    if request.guided_json:
        schema = _normalize_json_schema_object(request.guided_json)
        character_level_parser = JsonSchemaParser(schema)
    elif request.guided_choice:
        character_level_parser = UnionParser(
            [StringParser(choice) for choice in request.guided_choice])
    elif request.guided_regex:
        character_level_parser = RegexParser(request.guided_regex)
    elif request.guided_grammar:
        # CFG grammars are not supported by lm-format-enforcer; fall back to
        # the outlines backend.
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        # A schema of None means "any valid JSON object".
        character_level_parser = JsonSchemaParser(None)
    else:
        # No guided decoding parameters were provided.
        return None
    logits_processor = build_aphrodite_logits_processor(
        tokenizer_data, character_level_parser)
    return logits_processor
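
# For illustration: a request with guided_choice=["yes", "no"] is handled
# above as UnionParser([StringParser("yes"), StringParser("no")]), i.e. the
# model may emit exactly one of the listed strings and nothing else.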


def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
    if isinstance(schema, str):
        return json_loads(schema)
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, BaseModel):
        return schema.model_json_schema()
    # Without this, an unsupported type would silently return None despite
    # the declared `-> dict` return type.
    raise AssertionError(f"Unsupported schema type: {type(schema)}")
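
# For illustration, each accepted form above yields a plain dict (the Pydantic
# model below is a hypothetical example, not part of this module):
#
#     _normalize_json_schema_object('{"type": "string"}')  # parsed from JSON
#     _normalize_json_schema_object({"type": "string"})    # returned as-is
#
#     class User(BaseModel):
#         name: str
#
#     _normalize_json_schema_object(User(name="Ada"))  # model_json_schema()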


@lru_cache
def _cached_build_aphrodite_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
    # lru_cache keys on the tokenizer object itself (hashable by identity),
    # so the expensive token-enforcer preprocessing runs once per tokenizer.
    return build_aphrodite_token_enforcer_tokenizer_data(tokenizer)
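

# Illustrative usage sketch, not part of the original module: build a logits
# processor for a guided-JSON request. Assumes CompletionRequest accepts
# `model`, `prompt`, and `guided_json` fields (as the handler above implies);
# "gpt2" is an arbitrary example tokenizer.
if __name__ == "__main__":
    import asyncio

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    request = CompletionRequest(
        model="gpt2",
        prompt="Describe a user as JSON: ",
        guided_json={"type": "object",
                     "properties": {"name": {"type": "string"}}},
    )
    processor = asyncio.run(
        get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer))
    # The returned callable is meant to be attached to
    # SamplingParams.logits_processors when sampling from the engine.
    print(processor)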