__init__.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from typing import Optional, Union
  2. from aphrodite.common.sampling_params import LogitsProcessorFunc
  3. from aphrodite.endpoints.openai.protocol import (
  4. ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
  5. CompletionRequest)
  6. from aphrodite.modeling.guided_decoding.guided_fields import (
  7. GuidedDecodingRequest)
  8. async def get_guided_decoding_logits_processor(
  9. guided_decoding_backend: str, request: Union[CompletionRequest,
  10. ChatCompletionRequest],
  11. tokenizer) -> Optional[LogitsProcessorFunc]:
  12. request = _adapt_request_for_tool_use(request)
  13. if guided_decoding_backend == 'outlines':
  14. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  15. get_outlines_guided_decoding_logits_processor)
  16. return await get_outlines_guided_decoding_logits_processor(
  17. request, tokenizer)
  18. if guided_decoding_backend == 'lm-format-enforcer':
  19. pass
  20. if guided_decoding_backend == 'lm-format-enforcer':
  21. from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import ( # noqa
  22. get_lm_format_enforcer_guided_decoding_logits_processor)
  23. return await get_lm_format_enforcer_guided_decoding_logits_processor(
  24. request, tokenizer)
  25. raise ValueError(
  26. f"Unknown guided decoding backend '{guided_decoding_backend}'. "
  27. "Must be one of 'outlines, 'lm-format-enforcer'")
  28. def get_local_guided_decoding_logits_processor(
  29. guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
  30. tokenizer) -> Optional[LogitsProcessorFunc]:
  31. # request = _adapt_request_for_tool_use(request)
  32. if guided_decoding_backend == 'outlines':
  33. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  34. get_local_outlines_guided_decoding_logits_processor)
  35. return get_local_outlines_guided_decoding_logits_processor(
  36. guided_options, tokenizer)
  37. if guided_decoding_backend == 'lm-format-enforcer':
  38. from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import ( # noqa
  39. get_local_lm_format_enforcer_guided_decoding_logits_processor)
  40. return get_local_lm_format_enforcer_guided_decoding_logits_processor(
  41. guided_options, tokenizer)
  42. raise ValueError(
  43. f"Unknown guided decoding backend '{guided_decoding_backend}'. "
  44. "Must be one of 'outlines, 'lm-format-enforcer'")
  45. def _adapt_request_for_tool_use(request: Union[CompletionRequest,
  46. ChatCompletionRequest]):
  47. # the legacy completion API does not support tool use
  48. if type(request) is CompletionRequest:
  49. return request
  50. # user has chosen to not use any tool
  51. if request.tool_choice == "none":
  52. return request
  53. # user has chosen to use a named tool
  54. if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
  55. tool_name = request.tool_choice.function.name
  56. tools = {tool.function.name: tool.function for tool in request.tools}
  57. if tool_name not in tools:
  58. raise ValueError(
  59. f"Tool '{tool_name}' has not been passed in `tools`.")
  60. tool = tools[tool_name]
  61. request.guided_json = tool.parameters
  62. return request