12345678910111213141516171819202122232425 |
- from typing import Optional, Union
- from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
- CompletionRequest)
- from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (
- get_lm_format_enforcer_guided_decoding_logits_processor)
- from aphrodite.modeling.guided_decoding.outlines_decoding import (
- get_outlines_guided_decoding_logits_processor)
- from aphrodite.common.sampling_params import LogitsProcessorFunc
async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str,
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessorFunc]:
    """Build a guided-decoding logits processor using the selected backend.

    Dispatches on ``guided_decoding_backend`` to the matching backend-specific
    builder. It awaits that builder and returns its result unchanged.

    Args:
        guided_decoding_backend: Backend name. Must be exactly ``'outlines'``
            or ``'lm-format-enforcer'`` (compared case-sensitively).
        request: The completion or chat-completion request. It is forwarded
            as-is to the backend builder.
        tokenizer: Forwarded as-is to the backend builder. The expected
            tokenizer type is not visible here (TODO confirm against the
            backend helpers).

    Returns:
        Whatever the selected backend builder returns. Per the annotation,
        this is a logits processor or ``None``.

    Raises:
        ValueError: If ``guided_decoding_backend`` is not a supported backend.
    """
    if guided_decoding_backend == 'outlines':
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)
    # Every supported backend returned above; anything else is a config error.
    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")
|