1
0

__init__.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from typing import Optional, Union
  2. from aphrodite.common.sampling_params import LogitsProcessorFunc
  3. from aphrodite.endpoints.openai.protocol import (
  4. ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
  5. CompletionRequest)
  6. from aphrodite.modeling.guided_decoding.guided_fields import (
  7. GuidedDecodingRequest)
  8. async def get_guided_decoding_logits_processor(
  9. guided_decoding_backend: str, request: Union[CompletionRequest,
  10. ChatCompletionRequest],
  11. tokenizer) -> Optional[LogitsProcessorFunc]:
  12. request = _adapt_request_for_tool_use(request)
  13. if guided_decoding_backend == 'outlines':
  14. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  15. get_outlines_guided_decoding_logits_processor)
  16. return await get_outlines_guided_decoding_logits_processor(
  17. request, tokenizer)
  18. if guided_decoding_backend == 'lm-format-enforcer':
  19. pass
  20. if guided_decoding_backend == 'lm-format-enforcer':
  21. from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import ( # noqa
  22. get_lm_format_enforcer_guided_decoding_logits_processor)
  23. return await get_lm_format_enforcer_guided_decoding_logits_processor(
  24. request, tokenizer)
  25. raise ValueError(
  26. f"Unknown guided decoding backend '{guided_decoding_backend}'. "
  27. "Must be one of 'outlines, 'lm-format-enforcer'")
  28. def get_local_guided_decoding_logits_processor(
  29. guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
  30. tokenizer) -> Optional[LogitsProcessorFunc]:
  31. # request = _adapt_request_for_tool_use(request)
  32. if guided_decoding_backend == 'outlines':
  33. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  34. get_local_outlines_guided_decoding_logits_processor)
  35. return get_local_outlines_guided_decoding_logits_processor(
  36. guided_options, tokenizer)
  37. if guided_decoding_backend == 'lm-format-enforcer':
  38. from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import ( # noqa
  39. get_local_lm_format_enforcer_guided_decoding_logits_processor)
  40. return get_local_lm_format_enforcer_guided_decoding_logits_processor(
  41. guided_options, tokenizer)
  42. raise ValueError(
  43. f"Unknown guided decoding backend '{guided_decoding_backend}'. "
  44. "Must be one of 'outlines, 'lm-format-enforcer'")
  45. def _adapt_request_for_tool_use(request: Union[CompletionRequest,
  46. ChatCompletionRequest]):
  47. # the legacy completion API does not support tool use
  48. if type(request) is CompletionRequest:
  49. return request
  50. # user has chosen to not use any tool,
  51. # OR is allowing the model to choose a tool.
  52. if request.tool_choice == "none" or request.tool_choice == "auto":
  53. return request
  54. # user has chosen to use a named tool
  55. if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
  56. tool_name = request.tool_choice.function.name
  57. tools = {tool.function.name: tool.function for tool in request.tools}
  58. if tool_name not in tools:
  59. raise ValueError(
  60. f"Tool '{tool_name}' has not been passed in `tools`.")
  61. tool = tools[tool_name]
  62. request.guided_json = tool.parameters
  63. return request