outlines_decoding.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import asyncio
  2. import concurrent.futures
  3. from copy import copy
  4. from enum import Enum
  5. from functools import lru_cache
  6. from json import dumps as json_dumps
  7. from re import escape as regex_escape
  8. from typing import Tuple, Union
  9. from pydantic import BaseModel
  10. from transformers import PreTrainedTokenizerBase
  11. from aphrodite.endpoints.openai.protocol import (
  12. ChatCompletionRequest,
  13. CompletionRequest,
  14. )
  15. from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
  16. CFGLogitsProcessor,
  17. JSONLogitsProcessor,
  18. RegexLogitsProcessor,
  19. )
  20. class GuidedDecodingMode(Enum):
  21. JSON = "json"
  22. REGEX = "regex"
  23. CHOICE = "choice"
  24. GRAMMAR = "grammar"
# Grammar used when a request asks for `response_format.type == "json_object"`
# without supplying its own schema or grammar.
#
# Adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
# The main difference is that the start rule was changed to
# `start: object | array`, so scalar values are denied as the root of the
# JSON. Starting with scalars as the root seems to cause llama to generate
# without stop.
JSON_GRAMMAR = r"""
?start: object | array
?value: object
| array
| UNESCAPED_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null
array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : UNESCAPED_STRING ":" value
%import common.UNESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS
%ignore WS
"""

# Executor used for generating the logits processor FSM off the event loop.
# Starts as None and is created on first use by
# get_outlines_guided_decoding_logits_processor.
global_thread_pool = None  # used for generating logits processor fsm
  48. async def get_outlines_guided_decoding_logits_processor(
  49. request: Union[CompletionRequest, ChatCompletionRequest],
  50. tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
  51. """
  52. Given an OpenAI-compatible request, check for guided decoding parameters
  53. and get the necessary logits processor for the given guide.
  54. We cache logit processors by (guide, tokenizer), and on cache hit
  55. we make a shallow copy to reuse the same underlying FSM.
  56. """
  57. global global_thread_pool
  58. guide, mode = _get_guide_and_mode(request)
  59. if not guide:
  60. return None
  61. if global_thread_pool is None:
  62. global_thread_pool = concurrent.futures.ThreadPoolExecutor(
  63. max_workers=2)
  64. loop = asyncio.get_running_loop()
  65. result = await loop.run_in_executor(global_thread_pool,
  66. _get_cached_logits_processor, guide,
  67. tokenizer, mode)
  68. logits_processor = copy(result)
  69. # reset logits processor's internal state
  70. logits_processor.init_state()
  71. return logits_processor
  72. def _get_guide_and_mode(
  73. request: Union[CompletionRequest, ChatCompletionRequest]
  74. ) -> Tuple[str, GuidedDecodingMode]:
  75. if request.guided_json:
  76. json = request.guided_json
  77. if isinstance(json, dict):
  78. # turn dict into hashable string
  79. json = json_dumps(json)
  80. elif isinstance(json, BaseModel):
  81. # use pydantic signature so that different model classes
  82. # with the same fields will get hashed the same
  83. json = str(json.__signature__)
  84. return json, GuidedDecodingMode.JSON
  85. elif request.guided_regex:
  86. return request.guided_regex, GuidedDecodingMode.REGEX
  87. elif request.guided_choice:
  88. # choice just uses regex
  89. choices = [
  90. regex_escape(str(choice)) for choice in request.guided_choice
  91. ]
  92. choices_regex = "(" + "|".join(choices) + ")"
  93. return choices_regex, GuidedDecodingMode.CHOICE
  94. elif request.guided_grammar:
  95. return request.guided_grammar, GuidedDecodingMode.GRAMMAR
  96. elif (request.response_format is not None
  97. and request.response_format.type == "json_object"):
  98. return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
  99. else:
  100. return None, None
  101. @lru_cache(maxsize=32)
  102. def _get_cached_logits_processor(guide: str,
  103. tokenizer: PreTrainedTokenizerBase,
  104. mode: GuidedDecodingMode):
  105. if mode == GuidedDecodingMode.JSON:
  106. return JSONLogitsProcessor(guide, tokenizer)
  107. elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
  108. return RegexLogitsProcessor(guide, tokenizer)
  109. elif mode == GuidedDecodingMode.GRAMMAR:
  110. return CFGLogitsProcessor(guide, tokenizer)
  111. else:
  112. raise ValueError(f"Unknown guided decoding mode {mode}")