1
0

__init__.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from typing import Optional, Union
  2. from aphrodite.common.sampling_params import LogitsProcessorFunc
  3. from aphrodite.endpoints.openai.protocol import (
  4. ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
  5. CompletionRequest)
  6. from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import \
  7. get_lm_format_enforcer_guided_decoding_logits_processor
  8. from aphrodite.modeling.guided_decoding.outlines_decoding import \
  9. get_outlines_guided_decoding_logits_processor
  10. async def get_guided_decoding_logits_processor(
  11. guided_decoding_backend: str, request: Union[CompletionRequest,
  12. ChatCompletionRequest],
  13. tokenizer) -> Optional[LogitsProcessorFunc]:
  14. request = _adapt_request_for_tool_use(request)
  15. if guided_decoding_backend == 'outlines':
  16. return await get_outlines_guided_decoding_logits_processor(
  17. request, tokenizer)
  18. if guided_decoding_backend == 'lm-format-enforcer':
  19. return await get_lm_format_enforcer_guided_decoding_logits_processor(
  20. request, tokenizer)
  21. raise ValueError(
  22. f"Unknown guided decoding backend '{guided_decoding_backend}'. "
  23. "Must be one of 'outlines, 'lm-format-enforcer'")
  24. def _adapt_request_for_tool_use(request: Union[CompletionRequest,
  25. ChatCompletionRequest]):
  26. # the legacy completion API does not support tool use
  27. if type(request) is CompletionRequest:
  28. return request
  29. # user has chosen to not use any tool
  30. if request.tool_choice == "none":
  31. return request
  32. # user has chosen to use a named tool
  33. if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
  34. tool_name = request.tool_choice.function.name
  35. tools = {tool.function.name: tool.function for tool in request.tools}
  36. if tool_name not in tools:
  37. raise ValueError(
  38. f"Tool '{tool_name}' has not been passed in `tools`.")
  39. tool = tools[tool_name]
  40. request.guided_json = tool.parameters
  41. return request