__init__.py

from typing import Optional, Union

from aphrodite.common.sampling_params import LogitsProcessorFunc
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    CompletionRequest)
from aphrodite.modeling.guided_decoding.guided_fields import (
    GuidedDecodingRequest)
from aphrodite.modeling.guided_decoding.outlines_decoding import (
    get_local_outlines_guided_decoding_logits_processor,
    get_outlines_guided_decoding_logits_processor)


async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str, request: Union[CompletionRequest,
                                                      ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessorFunc]:
    """Build a logits processor for an OpenAI-style request, dispatching on
    the configured guided decoding backend."""
    request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_lm_format_enforcer_guided_decoding_logits_processor)
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
        guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
        tokenizer) -> Optional[LogitsProcessorFunc]:
    """Build a logits processor from a GuidedDecodingRequest, dispatching on
    the configured guided decoding backend."""
    # request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        return get_local_outlines_guided_decoding_logits_processor(
            guided_options, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_options, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                               ChatCompletionRequest]):
    # the legacy completion API does not support tool use
    if type(request) is CompletionRequest:
        return request

    # user has chosen not to use any tool,
    # OR is allowing the model to choose a tool.
    if request.tool_choice == "none" or request.tool_choice == "auto":
        return request

    # user has chosen to use a named tool
    if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
        tool_name = request.tool_choice.function.name
        tools = {tool.function.name: tool.function for tool in request.tools}
        if tool_name not in tools:
            raise ValueError(
                f"Tool '{tool_name}' has not been passed in `tools`.")
        tool = tools[tool_name]
        request.guided_json = tool.parameters

    return request
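

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module): one plausible way to
# wire the synchronous entry point into sampling. The names assumed below
# (`guided_regex` on GuidedDecodingRequest and `logits_processors` on
# SamplingParams) come from the surrounding codebase and may differ in your
# checkout.
#
#     from transformers import AutoTokenizer
#     from aphrodite.common.sampling_params import SamplingParams
#     from aphrodite.modeling.guided_decoding import (
#         get_local_guided_decoding_logits_processor)
#     from aphrodite.modeling.guided_decoding.guided_fields import (
#         GuidedDecodingRequest)
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     # constrain generations to an ISO date such as 2024-05-01
#     options = GuidedDecodingRequest(guided_regex=r"\d{4}-\d{2}-\d{2}")
#     processor = get_local_guided_decoding_logits_processor(
#         "outlines", options, tokenizer)
#     params = SamplingParams(
#         logits_processors=[processor] if processor is not None else None)
# ---------------------------------------------------------------------------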