import asyncio
import concurrent.futures
from enum import Enum
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from aphrodite.endpoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    CompletionRequest)
from aphrodite.modeling.guided_decoding.guided_fields import (
    GuidedDecodingRequest)
from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
    CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)


class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"
    GRAMMAR = "grammar"


# Adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that `start: value` was changed to
# `start: object | array`, so scalar values are not allowed as the root of
# the generated JSON. Allowing a scalar at the root seems to cause llama to
# generate without stopping.
JSON_GRAMMAR = r"""
?start: object | array

?value: object
      | array
      | UNESCAPED_STRING
      | SIGNED_NUMBER      -> number
      | "true"             -> true
      | "false"            -> false
      | "null"             -> null

array  : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair   : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""

# Used for generating logits processor FSMs without blocking the event loop.
global_thread_pool = None


async def get_outlines_guided_decoding_logits_processor(
    request: Union[CompletionRequest, ChatCompletionRequest],
    tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide or not mode:
        return None

    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()

    return await loop.run_in_executor(global_thread_pool,
                                      _get_logits_processor, guide, tokenizer,
                                      mode, request.guided_whitespace_pattern)
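
# A minimal usage sketch (hypothetical request values; in practice the
# request and tokenizer come from the OpenAI-compatible server):
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     request = CompletionRequest(model="m", prompt="p", guided_regex=r"\d+")
#     processor = await get_outlines_guided_decoding_logits_processor(
#         request, tokenizer)
#     # -> RegexLogitsProcessor, ready to attach to the sampling parameters.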


def get_local_outlines_guided_decoding_logits_processor(
    guided_options: GuidedDecodingRequest,
    tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
           None]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    guide, mode = _get_guide_and_mode(guided_options)
    if not guide or not mode:
        return None

    return _get_logits_processor(guide, tokenizer, mode,
                                 guided_options.guided_whitespace_pattern)


def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest,
                   GuidedDecodingRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
    # If the request is a chat completion request AND the tool choice is a
    # named tool choice, do guided decoding using that tool's parameters as
    # the JSON schema.
    if isinstance(request, ChatCompletionRequest) and isinstance(
            request.tool_choice, ChatCompletionNamedToolChoiceParam):
        # Guided generation for tools/functions parameters
        if request.tool_choice.type == "function":
            for tool in request.tools:
                if (tool.type == "function" and tool.function.name
                        == request.tool_choice.function.name):
                    json = json_dumps(tool.function.parameters,
                                      sort_keys=True)
                    return json, GuidedDecodingMode.JSON
        return None, None

    elif request.guided_json:
        if isinstance(request.guided_json, dict):
            # Turn the dict into a hashable string.
            json = json_dumps(request.guided_json)
        elif isinstance(request.guided_json, BaseModel):
            # Use the pydantic signature so that different model classes
            # with the same fields will get hashed the same way.
            json = str(request.guided_json.__signature__)
        else:
            json = request.guided_json
        return json, GuidedDecodingMode.JSON

    elif request.guided_regex:
        return request.guided_regex, GuidedDecodingMode.REGEX

    elif request.guided_choice:
        # Choice is just a regex alternation over the escaped options.
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE

    elif request.guided_grammar:
        return request.guided_grammar, GuidedDecodingMode.GRAMMAR

    elif (not isinstance(request, GuidedDecodingRequest)
          and request.response_format is not None
          and request.response_format.type == "json_object"):
        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR

    elif (not isinstance(request, GuidedDecodingRequest)
          and request.response_format is not None
          and request.response_format.type == "json_schema"
          and request.response_format.json_schema is not None
          and request.response_format.json_schema.json_schema is not None):
        json = json_dumps(request.response_format.json_schema.json_schema)
        return json, GuidedDecodingMode.JSON

    else:
        return None, None
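
# For instance, guided_choice=["yes", "no", "2.5"] becomes the regex guide
# "(yes|no|2\.5)" with GuidedDecodingMode.CHOICE; regex_escape protects
# metacharacters such as '.' in the options.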


def _get_logits_processor(
    guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
    whitespace_pattern: Union[str, None]
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
    elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
        return RegexLogitsProcessor(guide, tokenizer)
    elif mode == GuidedDecodingMode.GRAMMAR:
        return CFGLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
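

if __name__ == "__main__":
    # A minimal smoke-test sketch of the synchronous entry point; assumes
    # `transformers` can fetch the gpt2 tokenizer and `outlines` is installed.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    options = GuidedDecodingRequest(guided_choice=["yes", "no"])
    processor = get_local_outlines_guided_decoding_logits_processor(
        options, tokenizer)
    print(type(processor).__name__)  # expected: RegexLogitsProcessor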