outlines_decoding.py

import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Optional, Tuple, Union

from pydantic import BaseModel

from aphrodite.endpoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
)
from aphrodite.modeling.outlines_logits_processors import (
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)

class GuidedDecodingMode(Enum):
    JSON = "json"
    REGEX = "regex"
    CHOICE = "choice"


global_thread_pool = None  # used for generating logits processor FSMs

async def get_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer
) -> Optional[Union[JSONLogitsProcessor, RegexLogitsProcessor]]:
    """Given an OpenAI-compatible request, check for guided decoding
    parameters and get the necessary logits processor for the given guide.

    We cache logits processors by (guide, tokenizer), and on a cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    global global_thread_pool
    guide, mode = _get_guide_and_mode(request)
    if not guide:
        return None

    # Building the FSM can be slow, so run it in a worker thread rather
    # than blocking the event loop.
    if global_thread_pool is None:
        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=2)
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(global_thread_pool,
                                        _get_cached_logits_processor,
                                        guide, tokenizer, mode)

    # Shallow-copy the cached processor and reset its per-request state.
    logits_processor = copy(result)
    logits_processor.init_state()
    return logits_processor
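
# Usage sketch (illustrative only; the request field values and surrounding
# variables below are hypothetical, not defined in this module). Because the
# FSM build is offloaded to the thread pool, this coroutine can be awaited
# safely from a request handler:
#
#     request = CompletionRequest(
#         model="example-model",
#         prompt="Answer yes or no:",
#         guided_choice=["yes", "no"],
#     )
#     processor = await get_guided_decoding_logits_processor(
#         request, tokenizer)
#     if processor is not None:
#         sampling_params.logits_processors = [processor]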

def _get_guide_and_mode(
    request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[Optional[str], Optional[GuidedDecodingMode]]:
    if request.guided_json:
        if not isinstance(request.guided_json, (str, dict, BaseModel)):
            raise TypeError("JSON schema must be str, dict, or BaseModel")

        json = request.guided_json
        if isinstance(json, dict):
            # turn the dict into a hashable string
            json = json_dumps(json, sort_keys=True)
        elif isinstance(json, BaseModel):
            # use the pydantic signature so that different model classes
            # with the same fields hash the same
            json = str(json.__signature__)
        return json, GuidedDecodingMode.JSON

    elif request.guided_regex:
        if not isinstance(request.guided_regex, str):
            raise TypeError("Regex must be a string")
        return request.guided_regex, GuidedDecodingMode.REGEX

    elif request.guided_choice:
        if not isinstance(request.guided_choice, list):
            raise TypeError("Choices must be a list")

        # choice is implemented as a regex alternation over the options
        choices = [
            regex_escape(str(choice)) for choice in request.guided_choice
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE

    else:
        return None, None
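
# A quick illustration of the CHOICE branch above (a sketch of an
# interpreter session, not code executed by this module):
#
#     >>> choices = [regex_escape(str(c)) for c in ["yes", "no", "maybe?"]]
#     >>> "(" + "|".join(choices) + ")"
#     '(yes|no|maybe\\?)'
#
# regex_escape keeps characters like "?" literal, so an option can never
# alter the pattern's structure.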

@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, tokenizer,
                                 mode: GuidedDecodingMode):
    if mode == GuidedDecodingMode.JSON:
        return JSONLogitsProcessor(guide, tokenizer)
    elif mode in (GuidedDecodingMode.REGEX, GuidedDecodingMode.CHOICE):
        return RegexLogitsProcessor(guide, tokenizer)
    else:
        raise ValueError(f"Unknown guided decoding mode {mode}")
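
# Cache semantics, sketched: lru_cache keys on (guide, tokenizer, mode), which
# assumes the tokenizer object is hashable (plain Python objects hash by
# identity). Repeated requests with the same guide then reuse the compiled
# FSM, while the caller's copy() gives each request its own decoding state:
#
#     first = _get_cached_logits_processor(guide, tokenizer, mode)   # builds FSM
#     second = _get_cached_logits_processor(guide, tokenizer, mode)  # cache hit
#     assert first is second    # same processor, same underlying FSM
#     fresh = copy(first)
#     fresh.init_state()        # fresh per-request state, shared FSM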