# lm_format_enforcer_decoding.py
  1. from functools import lru_cache
  2. from json import loads as json_loads
  3. from typing import Optional, Union
  4. from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
  5. RegexParser, StringParser,
  6. TokenEnforcerTokenizerData, UnionParser)
  7. from pydantic import BaseModel
  8. from transformers import PreTrainedTokenizerBase
  9. from aphrodite.common.sampling_params import LogitsProcessorFunc
  10. from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
  11. CompletionRequest)
  12. from aphrodite.modeling.guided_decoding.guided_fields import (
  13. GuidedDecodingRequest)
  14. from aphrodite.modeling.guided_decoding.lm_format_enforcer_logits_processors import ( # noqa: E501
  15. build_aphrodite_logits_processor,
  16. build_aphrodite_token_enforcer_tokenizer_data)
  17. async def get_lm_format_enforcer_guided_decoding_logits_processor(
  18. request: Union[CompletionRequest, ChatCompletionRequest],
  19. tokenizer) -> Optional[LogitsProcessorFunc]:
  20. """
  21. Given an OpenAI-compatible request, check for guided decoding parameters
  22. and get the necessary logits processor for the given guide.
  23. We cache logit processors by (guide, tokenizer), and on cache hit
  24. we make a shallow copy to reuse the same underlying FSM.
  25. """
  26. tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
  27. tokenizer)
  28. character_level_parser: CharacterLevelParser
  29. if request.guided_json:
  30. schema = _normalize_json_schema_object(request.guided_json)
  31. character_level_parser = JsonSchemaParser(schema)
  32. elif request.guided_choice:
  33. character_level_parser = UnionParser(
  34. [StringParser(choice) for choice in request.guided_choice])
  35. elif request.guided_regex:
  36. character_level_parser = RegexParser(request.guided_regex)
  37. elif request.guided_grammar:
  38. # CFG grammar not supported by LMFE, revert to outlines
  39. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  40. get_outlines_guided_decoding_logits_processor)
  41. return await get_outlines_guided_decoding_logits_processor(
  42. request, tokenizer)
  43. elif (request.response_format is not None
  44. and request.response_format.type == "json_object"):
  45. character_level_parser = JsonSchemaParser(
  46. None) # None means any json object
  47. elif (request.response_format is not None
  48. and request.response_format.type == "json_schema"
  49. and request.response_format.json_schema is not None
  50. and request.response_format.json_schema.json_schema is not None):
  51. schema = _normalize_json_schema_object(
  52. request.response_format.json_schema.json_schema)
  53. character_level_parser = JsonSchemaParser(schema)
  54. else:
  55. return None
  56. logits_processor = build_aphrodite_logits_processor(
  57. tokenizer_data, character_level_parser)
  58. return logits_processor
  59. def get_local_lm_format_enforcer_guided_decoding_logits_processor(
  60. guided_options: GuidedDecodingRequest,
  61. tokenizer) -> Optional[LogitsProcessorFunc]:
  62. """
  63. Given an OpenAI-compatible request, check for guided decoding parameters
  64. and get the necessary logits processor for the given guide.
  65. We cache logit processors by (guide, tokenizer), and on cache hit
  66. we make a shallow copy to reuse the same underlying FSM.
  67. """
  68. tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
  69. tokenizer)
  70. character_level_parser: CharacterLevelParser
  71. if guided_options.guided_json:
  72. schema = _normalize_json_schema_object(guided_options.guided_json)
  73. character_level_parser = JsonSchemaParser(schema)
  74. elif guided_options.guided_choice:
  75. character_level_parser = UnionParser(
  76. [StringParser(choice) for choice in guided_options.guided_choice])
  77. elif guided_options.guided_regex:
  78. character_level_parser = RegexParser(guided_options.guided_regex)
  79. elif guided_options.guided_grammar:
  80. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  81. get_local_outlines_guided_decoding_logits_processor)
  82. # CFG grammar not supported by LMFE, revert to outlines
  83. return get_local_outlines_guided_decoding_logits_processor(
  84. guided_options, tokenizer)
  85. elif guided_options.guided_json_object:
  86. # None means any json object
  87. character_level_parser = JsonSchemaParser(None)
  88. else:
  89. return None
  90. logits_processor = build_aphrodite_logits_processor(
  91. tokenizer_data, character_level_parser)
  92. return logits_processor
  93. def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
  94. if isinstance(schema, str):
  95. return json_loads(schema)
  96. if isinstance(schema, dict):
  97. return schema
  98. if isinstance(schema, BaseModel):
  99. return schema.model_json_schema()
@lru_cache
def _cached_build_aphrodite_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
    """Build and memoize the LMFE tokenizer data for *tokenizer*.

    Cached so repeated guided-decoding requests with the same tokenizer reuse
    one TokenEnforcerTokenizerData rather than rebuilding it on every call
    (the build is presumably expensive — that is why this wrapper exists).
    NOTE(review): lru_cache keys on the tokenizer object itself and keeps it
    referenced for the process lifetime; acceptable only while tokenizers are
    long-lived singletons — confirm against callers.
    """
    return build_aphrodite_token_enforcer_tokenizer_data(tokenizer)