from typing import Optional, Union

from aphrodite.common.sampling_params import LogitsProcessorFunc
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    CompletionRequest)
from aphrodite.modeling.guided_decoding.guided_fields import (
    GuidedDecodingRequest)
from aphrodite.triton_utils import HAS_TRITON
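
# outlines_decoding is imported only when Triton is available (HAS_TRITON).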
if HAS_TRITON:
    from aphrodite.modeling.guided_decoding.outlines_decoding import (
        get_local_outlines_guided_decoding_logits_processor,
        get_outlines_guided_decoding_logits_processor)


async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str,
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessorFunc]:
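    """Build a logits processor that constrains generation for an OpenAI
    completion/chat request using the selected guided decoding backend."""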
    request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        if HAS_TRITON:
            return await get_outlines_guided_decoding_logits_processor(
                request, tokenizer)
        # Without Triton the outlines processors were never imported, so
        # fall through to the error below.
    if guided_decoding_backend == 'lm-format-enforcer':
        from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_lm_format_enforcer_guided_decoding_logits_processor)
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'.")


def get_local_guided_decoding_logits_processor(
        guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
        tokenizer) -> Optional[LogitsProcessorFunc]:
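    """Synchronous variant that builds a logits processor directly from a
    GuidedDecodingRequest, without the OpenAI request adaptation step."""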
    # request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        if HAS_TRITON:
            return get_local_outlines_guided_decoding_logits_processor(
                guided_options, tokenizer)
        # Without Triton the outlines processors were never imported, so
        # fall through to the error below.
    if guided_decoding_backend == 'lm-format-enforcer':
        from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_options, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'.")


def _adapt_request_for_tool_use(
        request: Union[CompletionRequest, ChatCompletionRequest]):
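    """If a chat request names a specific tool, constrain the output to that
    tool's JSON parameter schema via `guided_json`."""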
    # the legacy completion API does not support tool use
    if type(request) is CompletionRequest:
        return request

    # user has chosen to not use any tool
    if request.tool_choice == "none":
        return request

    # user has chosen to use a named tool
    if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
        tool_name = request.tool_choice.function.name
        tools = {tool.function.name: tool.function for tool in request.tools}
        if tool_name not in tools:
            raise ValueError(
                f"Tool '{tool_name}' has not been passed in `tools`.")
        tool = tools[tool_name]
        request.guided_json = tool.parameters

    return request