# lm_format_enforcer_decoding.py
  1. from functools import lru_cache
  2. from json import loads as json_loads
  3. from typing import Optional, Union
  4. from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
  5. RegexParser, StringParser,
  6. TokenEnforcerTokenizerData, UnionParser)
  7. from pydantic import BaseModel
  8. from transformers import PreTrainedTokenizerBase
  9. from aphrodite.common.sampling_params import LogitsProcessorFunc
  10. from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
  11. CompletionRequest)
  12. from aphrodite.modeling.guided_decoding.guided_fields import (
  13. GuidedDecodingRequest)
  14. from aphrodite.modeling.guided_decoding.lm_format_enforcer_logits_processors import ( # noqa: E501
  15. build_aphrodite_logits_processor,
  16. build_aphrodite_token_enforcer_tokenizer_data)
  17. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  18. get_local_outlines_guided_decoding_logits_processor,
  19. get_outlines_guided_decoding_logits_processor)
  20. async def get_lm_format_enforcer_guided_decoding_logits_processor(
  21. request: Union[CompletionRequest, ChatCompletionRequest],
  22. tokenizer) -> Optional[LogitsProcessorFunc]:
  23. """
  24. Given an OpenAI-compatible request, check for guided decoding parameters
  25. and get the necessary logits processor for the given guide.
  26. We cache logit processors by (guide, tokenizer), and on cache hit
  27. we make a shallow copy to reuse the same underlying FSM.
  28. """
  29. tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
  30. tokenizer)
  31. character_level_parser: CharacterLevelParser
  32. if request.guided_json:
  33. schema = _normalize_json_schema_object(request.guided_json)
  34. character_level_parser = JsonSchemaParser(schema)
  35. elif request.guided_choice:
  36. character_level_parser = UnionParser(
  37. [StringParser(choice) for choice in request.guided_choice])
  38. elif request.guided_regex:
  39. character_level_parser = RegexParser(request.guided_regex)
  40. elif request.guided_grammar:
  41. # CFG grammar not supported by LMFE, revert to outlines
  42. return await get_outlines_guided_decoding_logits_processor(
  43. request, tokenizer)
  44. elif (request.response_format is not None
  45. and request.response_format.type == "json_object"):
  46. character_level_parser = JsonSchemaParser(
  47. None) # None means any json object
  48. else:
  49. return None
  50. logits_processor = build_aphrodite_logits_processor(
  51. tokenizer_data, character_level_parser)
  52. return logits_processor
  53. def get_local_lm_format_enforcer_guided_decoding_logits_processor(
  54. guided_options: GuidedDecodingRequest,
  55. tokenizer) -> Optional[LogitsProcessorFunc]:
  56. """
  57. Given an OpenAI-compatible request, check for guided decoding parameters
  58. and get the necessary logits processor for the given guide.
  59. We cache logit processors by (guide, tokenizer), and on cache hit
  60. we make a shallow copy to reuse the same underlying FSM.
  61. """
  62. tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
  63. tokenizer)
  64. character_level_parser: CharacterLevelParser
  65. if guided_options.guided_json:
  66. schema = _normalize_json_schema_object(guided_options.guided_json)
  67. character_level_parser = JsonSchemaParser(schema)
  68. elif guided_options.guided_choice:
  69. character_level_parser = UnionParser(
  70. [StringParser(choice) for choice in guided_options.guided_choice])
  71. elif guided_options.guided_regex:
  72. character_level_parser = RegexParser(guided_options.guided_regex)
  73. elif guided_options.guided_grammar:
  74. # CFG grammar not supported by LMFE, revert to outlines
  75. return get_local_outlines_guided_decoding_logits_processor(
  76. guided_options, tokenizer)
  77. elif guided_options.guided_json_object:
  78. # None means any json object
  79. character_level_parser = JsonSchemaParser(None)
  80. else:
  81. return None
  82. logits_processor = build_aphrodite_logits_processor(
  83. tokenizer_data, character_level_parser)
  84. return logits_processor
  85. def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
  86. if isinstance(schema, str):
  87. return json_loads(schema)
  88. if isinstance(schema, dict):
  89. return schema
  90. if isinstance(schema, BaseModel):
  91. return schema.model_json_schema()
  92. @lru_cache
  93. def _cached_build_aphrodite_token_enforcer_tokenizer_data(
  94. tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
  95. return build_aphrodite_token_enforcer_tokenizer_data(tokenizer)