__init__.py

from typing import Optional, Union

from aphrodite.common.sampling_params import LogitsProcessorFunc
from aphrodite.endpoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
    CompletionRequest)
from aphrodite.modeling.guided_decoding.guided_fields import (
    GuidedDecodingRequest)
from aphrodite.modeling.guided_decoding.outlines_decoding import (
    get_local_outlines_guided_decoding_logits_processor,
    get_outlines_guided_decoding_logits_processor)


async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str, request: Union[CompletionRequest,
                                                     ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessorFunc]:
    # Resolve a named tool choice into a guided-JSON constraint first.
    request = _adapt_request_for_tool_use(request)

    if guided_decoding_backend == 'outlines':
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_lm_format_enforcer_guided_decoding_logits_processor)
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
        guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
        tokenizer) -> Optional[LogitsProcessorFunc]:
    # request = _adapt_request_for_tool_use(request)
    if guided_decoding_backend == 'outlines':
        return get_local_outlines_guided_decoding_logits_processor(
            guided_options, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_options, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                               ChatCompletionRequest]):
    # the legacy completion API does not support tool use
    if type(request) is CompletionRequest:
        return request

    # user has chosen to not use any tool
    if request.tool_choice == "none":
        return request

    # user has chosen to use a named tool
    if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
        tool_name = request.tool_choice.function.name
        tools = {tool.function.name: tool.function for tool in request.tools}
        if tool_name not in tools:
            raise ValueError(
                f"Tool '{tool_name}' has not been passed in `tools`.")
        tool = tools[tool_name]
        # Constrain generation to the named tool's parameter schema.
        request.guided_json = tool.parameters

    return request
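

# Minimal usage sketch: how the async entry point above might be called for a
# chat request that constrains output to JSON. The request fields, tokenizer
# source, and SamplingParams wiring below are illustrative assumptions, not
# this project's actual serving code, so the example is kept commented out.
#
#   import asyncio
#   from transformers import AutoTokenizer
#   from aphrodite.common.sampling_params import SamplingParams
#
#   async def _demo():
#       tokenizer = AutoTokenizer.from_pretrained("gpt2")
#       request = ChatCompletionRequest(
#           model="gpt2",
#           messages=[{"role": "user",
#                      "content": "Reply with a JSON object."}],
#           guided_json={"type": "object"},
#       )
#       processor = await get_guided_decoding_logits_processor(
#           "outlines", request, tokenizer)
#       # Attach the processor (when one is returned) to the sampling
#       # parameters used for generation.
#       return SamplingParams(
#           logits_processors=[processor] if processor is not None else None)
#
#   asyncio.run(_demo())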