1
0

__init__.py 1.2 KB

12345678910111213141516171819202122232425
  1. from typing import Optional, Union
  2. from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
  3. CompletionRequest)
  4. from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (
  5. get_lm_format_enforcer_guided_decoding_logits_processor)
  6. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  7. get_outlines_guided_decoding_logits_processor)
  8. from aphrodite.common.sampling_params import LogitsProcessorFunc
  9. async def get_guided_decoding_logits_processor(
  10. guided_decoding_backend: str, request: Union[CompletionRequest,
  11. ChatCompletionRequest],
  12. tokenizer) -> Optional[LogitsProcessorFunc]:
  13. if guided_decoding_backend == 'outlines':
  14. return await get_outlines_guided_decoding_logits_processor(
  15. request, tokenizer)
  16. if guided_decoding_backend == 'lm-format-enforcer':
  17. return await get_lm_format_enforcer_guided_decoding_logits_processor(
  18. request, tokenizer)
  19. raise ValueError(
  20. f"Unknown guided decoding backend '{guided_decoding_backend}'. "
  21. "Must be one of 'outlines, 'lm-format-enforcer'")