from typing import Optional, Union

from aphrodite.common.sampling_params import LogitsProcessorFunc
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    CompletionRequest)
from aphrodite.modeling.guided_decoding.guided_fields import (
    GuidedDecodingRequest)
from aphrodite.triton_utils import HAS_TRITON

# The outlines backend depends on Triton kernels, so import its helpers only
# when Triton is available.
if HAS_TRITON:
    from aphrodite.modeling.guided_decoding.outlines_decoding import (
        get_local_outlines_guided_decoding_logits_processor,
        get_outlines_guided_decoding_logits_processor)


async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str,
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessorFunc]:
    request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        if HAS_TRITON:
            return await get_outlines_guided_decoding_logits_processor(
                request, tokenizer)
        # Without Triton the outlines backend is unavailable; fall through
        # to the ValueError below.
    if guided_decoding_backend == 'lm-format-enforcer':
        from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_lm_format_enforcer_guided_decoding_logits_processor)
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
        guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
        tokenizer) -> Optional[LogitsProcessorFunc]:
    # request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        # NOTE: this path assumes HAS_TRITON, since the outlines helpers are
        # only imported when Triton is available.
        return get_local_outlines_guided_decoding_logits_processor(
            guided_options, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_options, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                               ChatCompletionRequest]):
    # The legacy completion API does not support tool use.
    if type(request) is CompletionRequest:
        return request

    # The user has chosen not to use any tool.
    if request.tool_choice == "none":
        return request

    # The user has chosen a named tool: constrain the output to that tool's
    # JSON parameter schema via guided_json.
    if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
        tool_name = request.tool_choice.function.name
        tools = {tool.function.name: tool.function for tool in request.tools}
        if tool_name not in tools:
            raise ValueError(
                f"Tool '{tool_name}' has not been passed in `tools`.")
        tool = tools[tool_name]
        request.guided_json = tool.parameters

    return request
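

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module's public surface).
# A minimal example of obtaining a logits processor through the local,
# synchronous entry point. The tokenizer name and the GuidedDecodingRequest
# field used here (`guided_regex`) are assumptions made for illustration;
# adjust them to your deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Hypothetical tokenizer choice; any Hugging Face tokenizer works here.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Constrain generation to a date-like pattern via the outlines backend
    # (requires Triton; see the conditional import above).
    options = GuidedDecodingRequest(guided_regex=r"\d{4}-\d{2}-\d{2}")
    logits_processor = get_local_guided_decoding_logits_processor(
        "outlines", options, tokenizer)
    print(logits_processor)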