# outlines_decoding.py
  1. import asyncio
  2. import concurrent.futures
  3. from copy import copy
  4. from enum import Enum
  5. from functools import lru_cache
  6. from json import dumps as json_dumps
  7. from re import escape as regex_escape
  8. from typing import Union, Tuple
  9. from pydantic import BaseModel
  10. from aphrodite.endpoints.openai.protocol import CompletionRequest, ChatCompletionRequest
  11. from aphrodite.modeling.outlines_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor
class GuidedDecodingMode(Enum):
    """Which kind of guided-decoding constraint a request asked for."""
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"
# Lazily-created, module-wide executor for building logits-processor FSMs
# off the event loop (presumably because FSM construction is expensive —
# the async entry point below offloads it via run_in_executor).
global_thread_pool: Union[concurrent.futures.ThreadPoolExecutor, None] = None  # used for generating logits processor fsm
  17. async def get_guided_decoding_logits_processor(
  18. request: Union[CompletionRequest, ChatCompletionRequest],
  19. tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
  20. """
  21. Given an OpenAI-compatible request, check for guided decoding parameters
  22. and get the necessary logits processor for the given guide.
  23. We cache logit processors by (guide, tokenizer), and on cache hit
  24. we make a shallow copy to reuse the same underlying FSM.
  25. """
  26. global global_thread_pool
  27. guide, mode = _get_guide_and_mode(request)
  28. if not guide:
  29. return None
  30. if global_thread_pool is None:
  31. global_thread_pool = concurrent.futures.ThreadPoolExecutor(
  32. max_workers=2)
  33. loop = asyncio.get_running_loop()
  34. result = await loop.run_in_executor(global_thread_pool,
  35. _get_cached_logits_processor, guide,
  36. tokenizer, mode)
  37. logits_processor = copy(result)
  38. # reset logits processor's internal state
  39. logits_processor.init_state()
  40. return logits_processor
  41. def _get_guide_and_mode(
  42. request: Union[CompletionRequest, ChatCompletionRequest]
  43. ) -> Tuple[str, GuidedDecodingMode]:
  44. if request.guided_json:
  45. if not isinstance(request.guided_json, (str, dict, BaseModel)):
  46. raise TypeError("JSON schema must be str, dict, or BaseModel")
  47. json = request.guided_json
  48. if isinstance(json, dict):
  49. # turn dict into hashable string
  50. json = json_dumps(json, sort_keys=True)
  51. elif isinstance(json, BaseModel):
  52. # use pydantic signature so that different model classes
  53. # with the same fields will get hashed the same
  54. json = str(json.__signature__)
  55. return json, GuidedDecodingMode.JSON
  56. elif request.guided_regex:
  57. if not isinstance(request.guided_regex, str):
  58. raise TypeError("Regex must be string")
  59. return request.guided_regex, GuidedDecodingMode.REGEX
  60. elif request.guided_choice:
  61. if not isinstance(request.guided_choice, list):
  62. raise TypeError("Choices must be a list")
  63. # choice just uses regex
  64. choices = [
  65. regex_escape(str(choice)) for choice in request.guided_choice
  66. ]
  67. choices_regex = "(" + "|".join(choices) + ")"
  68. return choices_regex, GuidedDecodingMode.CHOICE
  69. else:
  70. return None, None
  71. @lru_cache(maxsize=32)
  72. def _get_cached_logits_processor(guide: str, tokenizer,
  73. mode: GuidedDecodingMode):
  74. if mode == GuidedDecodingMode.JSON:
  75. return JSONLogitsProcessor(guide, tokenizer)
  76. elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
  77. return RegexLogitsProcessor(guide, tokenizer)
  78. else:
  79. raise ValueError(f"Unknown guided decoding mode {mode}")