outlines_decoding.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import asyncio
  2. import concurrent.futures
  3. from enum import Enum
  4. from json import dumps as json_dumps
  5. from re import escape as regex_escape
  6. from typing import Tuple, Union
  7. from pydantic import BaseModel
  8. from transformers import PreTrainedTokenizerBase
  9. from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
  10. CompletionRequest)
  11. from aphrodite.modeling.guided_decoding.guided_fields import (
  12. GuidedDecodingRequest)
  13. from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
  14. CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
  15. class GuidedDecodingMode(Enum):
  16. JSON = "json"
  17. REGEX = "regex"
  18. CHOICE = "choice"
  19. GRAMMAR = "grammar"
  20. # https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
  21. # the main difference is that we changed the start: value to
  22. # start: object | array, so we are denying scalar values as the root of the
  23. # JSON. Starting with scalars as the root seems to cause llama to generate
  24. # without stop.
  25. JSON_GRAMMAR = r"""
  26. ?start: object | array
  27. ?value: object
  28. | array
  29. | UNESCAPED_STRING
  30. | SIGNED_NUMBER -> number
  31. | "true" -> true
  32. | "false" -> false
  33. | "null" -> null
  34. array : "[" [value ("," value)*] "]"
  35. object : "{" [pair ("," pair)*] "}"
  36. pair : UNESCAPED_STRING ":" value
  37. %import common.UNESCAPED_STRING
  38. %import common.SIGNED_NUMBER
  39. %import common.WS
  40. %ignore WS
  41. """
  42. global_thread_pool = None # used for generating logits processor fsm
  43. async def get_outlines_guided_decoding_logits_processor(
  44. request: Union[CompletionRequest,
  45. ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
  46. ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
  47. None]:
  48. """
  49. Given an OpenAI-compatible request, check for guided decoding parameters
  50. and get the necessary logits processor for the given guide.
  51. We cache logit processors by (guide, tokenizer), and on cache hit
  52. we make a shallow copy to reuse the same underlying FSM.
  53. """
  54. global global_thread_pool
  55. guide, mode = _get_guide_and_mode(request)
  56. if not guide or not mode:
  57. return None
  58. if global_thread_pool is None:
  59. global_thread_pool = concurrent.futures.ThreadPoolExecutor(
  60. max_workers=2)
  61. loop = asyncio.get_running_loop()
  62. return await loop.run_in_executor(global_thread_pool,
  63. _get_logits_processor, guide, tokenizer,
  64. mode, request.guided_whitespace_pattern)
  65. def get_local_outlines_guided_decoding_logits_processor(
  66. guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase
  67. ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
  68. None]:
  69. """
  70. Given an OpenAI-compatible request, check for guided decoding parameters
  71. and get the necessary logits processor for the given guide.
  72. We cache logit processors by (guide, tokenizer), and on cache hit
  73. we make a shallow copy to reuse the same underlying FSM.
  74. """
  75. guide, mode = _get_guide_and_mode(guided_options)
  76. if not guide or not mode:
  77. return None
  78. return _get_logits_processor(guide, tokenizer, mode,
  79. guided_options.guided_whitespace_pattern)
  80. def _get_guide_and_mode(
  81. request: Union[CompletionRequest, ChatCompletionRequest,
  82. GuidedDecodingRequest]
  83. ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
  84. if request.guided_json:
  85. json = request.guided_json
  86. if isinstance(json, dict):
  87. # turn dict into hashable string
  88. json = json_dumps(json)
  89. elif isinstance(json, BaseModel):
  90. # use pydantic signature so that different model classes
  91. # with the same fields will get hashed the same
  92. json = str(json.__signature__)
  93. return json, GuidedDecodingMode.JSON
  94. elif request.guided_regex:
  95. return request.guided_regex, GuidedDecodingMode.REGEX
  96. elif request.guided_choice:
  97. # choice just uses regex
  98. choices = [
  99. regex_escape(str(choice)) for choice in request.guided_choice
  100. ]
  101. choices_regex = "(" + "|".join(choices) + ")"
  102. return choices_regex, GuidedDecodingMode.CHOICE
  103. elif request.guided_grammar:
  104. return request.guided_grammar, GuidedDecodingMode.GRAMMAR
  105. elif (not isinstance(request, GuidedDecodingRequest)
  106. and request.response_format is not None
  107. and request.response_format.type == "json_object"):
  108. return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
  109. elif (not isinstance(request, GuidedDecodingRequest)
  110. and request.response_format is not None
  111. and request.response_format.type == "json_schema"
  112. and request.response_format.json_schema is not None
  113. and request.response_format.json_schema.json_schema is not None):
  114. json = json_dumps(request.response_format.json_schema.json_schema)
  115. return json, GuidedDecodingMode.JSON
  116. else:
  117. return None, None
  118. def _get_logits_processor(
  119. guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
  120. whitespace_pattern: Union[str, None]
  121. ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
  122. if mode == GuidedDecodingMode.JSON:
  123. return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
  124. elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
  125. return RegexLogitsProcessor(guide, tokenizer)
  126. elif mode == GuidedDecodingMode.GRAMMAR:
  127. return CFGLogitsProcessor(guide, tokenizer)
  128. else:
  129. raise ValueError(f"Unknown guided decoding mode {mode}")