# lm_format_enforcer_decoding.py

from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union

from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
                              RegexParser, StringParser,
                              TokenEnforcerTokenizerData, UnionParser)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from aphrodite.common.sampling_params import LogitsProcessorFunc
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 CompletionRequest)
from aphrodite.modeling.guided_decoding.lm_format_enforcer_logits_processors import (  # noqa: E501
    build_aphrodite_logits_processor,
    build_aphrodite_token_enforcer_tokenizer_data)
from aphrodite.modeling.guided_decoding.outlines_decoding import (
    get_outlines_guided_decoding_logits_processor)


async def get_lm_format_enforcer_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessorFunc]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and build the logits processor for the given guide.

    The token enforcer tokenizer data is cached per tokenizer (see
    `_cached_build_aphrodite_token_enforcer_tokenizer_data` below), so
    repeated requests only pay for parser construction.
    """
    tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
        tokenizer)
    character_level_parser: CharacterLevelParser
    if request.guided_json:
        schema = _normalize_json_schema_object(request.guided_json)
        character_level_parser = JsonSchemaParser(schema)
    elif request.guided_choice:
        # Restrict the output to exactly one of the given choice strings.
        character_level_parser = UnionParser(
            [StringParser(choice) for choice in request.guided_choice])
    elif request.guided_regex:
        character_level_parser = RegexParser(request.guided_regex)
    elif request.guided_grammar:
        # CFG grammars are not supported by LM Format Enforcer; fall back
        # to the outlines backend.
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        # A schema of None means "any well-formed JSON object".
        character_level_parser = JsonSchemaParser(None)
    else:
        return None

    logits_processor = build_aphrodite_logits_processor(
        tokenizer_data, character_level_parser)
    return logits_processor
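
# A minimal usage sketch (illustrative, not part of this module). It assumes
# a HuggingFace tokenizer and that CompletionRequest accepts ``model``,
# ``prompt``, and ``guided_choice`` fields, per the protocol imported above:
#
#     import asyncio
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     request = CompletionRequest(model="gpt2",
#                                 prompt="Is the sky blue? ",
#                                 guided_choice=["yes", "no"])
#     processor = asyncio.run(
#         get_lm_format_enforcer_guided_decoding_logits_processor(
#             request, tokenizer))
#     # The returned callable would then typically be attached to
#     # SamplingParams.logits_processors.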


def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
    """Normalize the accepted guided_json forms into a plain dict schema."""
    if isinstance(schema, str):
        return json_loads(schema)
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, BaseModel):
        return schema.model_json_schema()
    # Without this, an unsupported type would silently return None despite
    # the ``-> dict`` annotation.
    raise TypeError(f"Unsupported schema type: {type(schema)}")
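
# Illustrative examples of the three accepted forms (all return a dict):
#
#     class Point(BaseModel):
#         x: int
#         y: int
#
#     _normalize_json_schema_object(Point(x=0, y=0))       # BaseModel instance
#     _normalize_json_schema_object(Point.model_json_schema())  # already a dict
#     _normalize_json_schema_object('{"type": "object"}')  # JSON-encoded string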


@lru_cache
def _cached_build_aphrodite_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
    # Building the token enforcer data is relatively expensive, so cache it;
    # lru_cache keys on the tokenizer argument's hash/identity, giving one
    # cached entry per tokenizer object.
    return build_aphrodite_token_enforcer_tokenizer_data(tokenizer)
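
# Because of the cache, repeated requests with the same tokenizer reuse the
# prebuilt data rather than rebuilding it (illustrative):
#
#     data1 = _cached_build_aphrodite_token_enforcer_tokenizer_data(tokenizer)
#     data2 = _cached_build_aphrodite_token_enforcer_tokenizer_data(tokenizer)
#     assert data1 is data2  # same cached TokenEnforcerTokenizerData object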