# outlines_decoding.py
import asyncio
import concurrent.futures
from enum import Enum
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from aphrodite.endpoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
)
from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
    CFGLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)
class GuidedDecodingMode(Enum):
    """Which kind of guided-decoding constraint a request asked for.

    The string values are the wire names used by the OpenAI-compatible
    request fields (``guided_json``, ``guided_regex``, ...).
    """
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"
    GRAMMAR = "grammar"
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# the main difference is that we changed the start: value to
# start: object | array, so we are denying scalar values as the root of the
# JSON. Starting with scalars as the root seems to cause llama to generate
# without stop.
# NOTE: this is a lark grammar consumed by CFGLogitsProcessor; the string
# content is runtime data and must stay in sync with outlines' json.lark.
JSON_GRAMMAR = r"""
?start: object | array

?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER      -> number
| "true"             -> true
| "false"            -> false
| "null"             -> null

array  : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair   : UNESCAPED_STRING ":" value

%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""
  45. global_thread_pool = None # used for generating logits processor fsm
  46. async def get_outlines_guided_decoding_logits_processor(
  47. request: Union[CompletionRequest,
  48. ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
  49. ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
  50. None]:
  51. """
  52. Given an OpenAI-compatible request, check for guided decoding parameters
  53. and get the necessary logits processor for the given guide.
  54. We cache logit processors by (guide, tokenizer), and on cache hit
  55. we make a shallow copy to reuse the same underlying FSM.
  56. """
  57. global global_thread_pool
  58. guide, mode = _get_guide_and_mode(request)
  59. if not guide or not mode:
  60. return None
  61. if global_thread_pool is None:
  62. global_thread_pool = concurrent.futures.ThreadPoolExecutor(
  63. max_workers=2)
  64. loop = asyncio.get_running_loop()
  65. return await loop.run_in_executor(global_thread_pool,
  66. _get_logits_processor, guide, tokenizer,
  67. mode, request.guided_whitespace_pattern)
  68. def _get_guide_and_mode(
  69. request: Union[CompletionRequest, ChatCompletionRequest]
  70. ) -> Tuple[str, GuidedDecodingMode]:
  71. if request.guided_json:
  72. json = request.guided_json
  73. if isinstance(json, dict):
  74. # turn dict into hashable string
  75. json = json_dumps(json)
  76. elif isinstance(json, BaseModel):
  77. # use pydantic signature so that different model classes
  78. # with the same fields will get hashed the same
  79. json = str(json.__signature__)
  80. return json, GuidedDecodingMode.JSON
  81. elif request.guided_regex:
  82. return request.guided_regex, GuidedDecodingMode.REGEX
  83. elif request.guided_choice:
  84. # choice just uses regex
  85. choices = [
  86. regex_escape(str(choice)) for choice in request.guided_choice
  87. ]
  88. choices_regex = "(" + "|".join(choices) + ")"
  89. return choices_regex, GuidedDecodingMode.CHOICE
  90. elif request.guided_grammar:
  91. return request.guided_grammar, GuidedDecodingMode.GRAMMAR
  92. elif (request.response_format is not None
  93. and request.response_format.type == "json_object"):
  94. return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
  95. else:
  96. return None, None
  97. def _get_logits_processor(
  98. guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode,
  99. whitespace_pattern: Union[str, None]
  100. ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
  101. if mode == GuidedDecodingMode.JSON:
  102. return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
  103. elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
  104. return RegexLogitsProcessor(guide, tokenizer)
  105. elif mode == GuidedDecodingMode.GRAMMAR:
  106. return CFGLogitsProcessor(guide, tokenizer)
  107. else:
  108. raise ValueError(f"Unknown guided decoding mode {mode}")