outlines_decoding.py

import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Optional, Tuple, Union

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from aphrodite.endpoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
)
from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
    CFGLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)


class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"
    GRAMMAR = "grammar"


# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that we changed "start: value" to
# "start: object | array", so scalar values are disallowed as the root of
# the JSON document. Starting with a scalar as the root seems to cause llama
# to generate without stopping.
JSON_GRAMMAR = r"""
?start: object | array

?value: object
      | array
      | UNESCAPED_STRING
      | SIGNED_NUMBER      -> number
      | "true"             -> true
      | "false"            -> false
      | "null"             -> null

array  : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair   : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""
global_thread_pool = None  # used for generating logits processor FSMs


async def get_outlines_guided_decoding_logits_processor(
    request: Union[CompletionRequest, ChatCompletionRequest],
    tokenizer: PreTrainedTokenizerBase,
) -> Optional[Union[JSONLogitsProcessor, RegexLogitsProcessor,
                    CFGLogitsProcessor]]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logits processors by (guide, tokenizer), and on a cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide:
        return None

    if global_thread_pool is None:
        # FSM construction can be slow; run it in a worker thread so the
        # event loop stays responsive.
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()

    result = await loop.run_in_executor(global_thread_pool,
                                        _get_cached_logits_processor, guide,
                                        tokenizer, mode,
                                        request.guided_whitespace_pattern)

    logits_processor = copy(result)
    # reset the logits processor's internal state so the cached instance is
    # never mutated by a request
    logits_processor.init_state()
    return logits_processor
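
# A minimal usage sketch, illustrative only (the field names below exist on
# the OpenAI-compatible CompletionRequest, but a real request carries many
# more fields):
#
#     from transformers import AutoTokenizer
#
#     async def demo():
#         tokenizer = AutoTokenizer.from_pretrained("gpt2")
#         request = CompletionRequest(model="gpt2", prompt="Pick one: ",
#                                     guided_choice=["yes", "no"])
#         lp = await get_outlines_guided_decoding_logits_processor(
#             request, tokenizer)
#         # lp is a RegexLogitsProcessor constrained to "(yes|no)"; attach it
#         # to the request's sampling parameters before generation.
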

def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[Optional[str], Optional[GuidedDecodingMode]]:
    if request.guided_json:
        json = request.guided_json
        if isinstance(json, dict):
            # turn dict into hashable string
            json = json_dumps(json)
        elif isinstance(json, BaseModel):
            # use pydantic signature so that different model classes
            # with the same fields will get hashed the same
            json = str(json.__signature__)
        return json, GuidedDecodingMode.JSON
    elif request.guided_regex:
        return request.guided_regex, GuidedDecodingMode.REGEX
    elif request.guided_choice:
        # choice just uses regex
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE
    elif request.guided_grammar:
        return request.guided_grammar, GuidedDecodingMode.GRAMMAR
    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
    else:
        return None, None
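
# For example, guided_choice=["yes", "no", "maybe?"] is mapped (after
# re.escape) to the regex "(yes|no|maybe\?)", which the REGEX/CHOICE branch
# below compiles into a token-level FSM.
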

@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str,
                                 tokenizer: PreTrainedTokenizerBase,
                                 mode: GuidedDecodingMode,
                                 whitespace_pattern: Optional[str]):
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.GRAMMAR:
        return CFGLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
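
# Cache-sharing sketch (illustrative, assuming a `tokenizer` is in scope):
# two calls with the same arguments hit the same lru_cache entry, so the
# expensive FSM is built only once; callers then take shallow copies.
#
#     a = _get_cached_logits_processor("(yes|no)", tokenizer,
#                                      GuidedDecodingMode.CHOICE, None)
#     b = _get_cached_logits_processor("(yes|no)", tokenizer,
#                                      GuidedDecodingMode.CHOICE, None)
#     assert a is b           # same cached object; FSM compiled once
#     lp = copy(a)            # per-request shallow copy shares the FSM
#     lp.init_state()         # fresh decoding state for this request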