lm_format_enforcer_decoding.py

from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union

from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
                              RegexParser, StringParser,
                              TokenEnforcerTokenizerData, UnionParser)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from aphrodite.common.sampling_params import LogitsProcessorFunc
from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                 CompletionRequest)
from aphrodite.modeling.guided_decoding.guided_fields import (
    GuidedDecodingRequest)
from aphrodite.modeling.guided_decoding.lm_format_enforcer_logits_processors import (  # noqa: E501
    build_aphrodite_logits_processor,
    build_aphrodite_token_enforcer_tokenizer_data)
from aphrodite.triton_utils import HAS_TRITON

if HAS_TRITON:
    from aphrodite.modeling.guided_decoding.outlines_decoding import (
        get_local_outlines_guided_decoding_logits_processor,
        get_outlines_guided_decoding_logits_processor)


async def get_lm_format_enforcer_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessorFunc]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    The token enforcer tokenizer data is cached per tokenizer, so repeated
    requests with the same tokenizer reuse it.
    """
    tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
        tokenizer)
    character_level_parser: CharacterLevelParser
    if request.guided_json:
        schema = _normalize_json_schema_object(request.guided_json)
        character_level_parser = JsonSchemaParser(schema)
    elif request.guided_choice:
        character_level_parser = UnionParser(
            [StringParser(choice) for choice in request.guided_choice])
    elif request.guided_regex:
        character_level_parser = RegexParser(request.guided_regex)
    elif request.guided_grammar:
        # CFG grammar not supported by LMFE, revert to outlines
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        character_level_parser = JsonSchemaParser(
            None)  # None means any json object
    else:
        return None

    logits_processor = build_aphrodite_logits_processor(
        tokenizer_data, character_level_parser)
    return logits_processor
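
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): how a server-side caller might
# use the coroutine above. The exact request construction is an assumption;
# only the fields referenced in the function (e.g. guided_regex) are relied
# on here.
#
#   from aphrodite.common.sampling_params import SamplingParams
#
#   request = CompletionRequest(model="my-model", prompt="Date: ",
#                               guided_regex=r"\d{4}-\d{2}-\d{2}")
#   processor = await get_lm_format_enforcer_guided_decoding_logits_processor(
#       request, tokenizer)
#   sampling_params = SamplingParams(
#       logits_processors=[processor] if processor else None)
# ---------------------------------------------------------------------------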


def get_local_lm_format_enforcer_guided_decoding_logits_processor(
        guided_options: GuidedDecodingRequest,
        tokenizer) -> Optional[LogitsProcessorFunc]:
    """
    Given a GuidedDecodingRequest, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    The token enforcer tokenizer data is cached per tokenizer, so repeated
    requests with the same tokenizer reuse it.
    """
    tokenizer_data = _cached_build_aphrodite_token_enforcer_tokenizer_data(
        tokenizer)
    character_level_parser: CharacterLevelParser
    if guided_options.guided_json:
        schema = _normalize_json_schema_object(guided_options.guided_json)
        character_level_parser = JsonSchemaParser(schema)
    elif guided_options.guided_choice:
        character_level_parser = UnionParser(
            [StringParser(choice) for choice in guided_options.guided_choice])
    elif guided_options.guided_regex:
        character_level_parser = RegexParser(guided_options.guided_regex)
    elif guided_options.guided_grammar:
        # CFG grammar not supported by LMFE, revert to outlines
        return get_local_outlines_guided_decoding_logits_processor(
            guided_options, tokenizer)
    elif guided_options.guided_json_object:
        # None means any json object
        character_level_parser = JsonSchemaParser(None)
    else:
        return None

    logits_processor = build_aphrodite_logits_processor(
        tokenizer_data, character_level_parser)
    return logits_processor
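
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module): constraining generation to a
# fixed set of choices with the local (offline) entry point. It assumes a
# HuggingFace tokenizer and that GuidedDecodingRequest accepts the fields
# referenced above as keyword arguments.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   options = GuidedDecodingRequest(guided_choice=["yes", "no", "maybe"])
#   processor = get_local_lm_format_enforcer_guided_decoding_logits_processor(
#       options, tokenizer)
#   # `processor` can then be passed via SamplingParams.logits_processors.
# ---------------------------------------------------------------------------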


def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
    if isinstance(schema, str):
        return json_loads(schema)
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, BaseModel):
        return schema.model_json_schema()
    # Anything else cannot be converted into a JSON schema dict.
    raise AssertionError(f"Unsupported schema type {schema}")


@lru_cache
def _cached_build_aphrodite_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
    return build_aphrodite_token_enforcer_tokenizer_data(tokenizer)