__init__.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from typing import Optional, Union
  2. from aphrodite.common.sampling_params import LogitsProcessorFunc
  3. from aphrodite.endpoints.openai.protocol import (
  4. ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
  5. CompletionRequest)
  6. from aphrodite.modeling.guided_decoding.guided_fields import (
  7. GuidedDecodingRequest)
  8. from aphrodite.triton_utils import HAS_TRITON
  9. if HAS_TRITON:
  10. from aphrodite.modeling.guided_decoding.outlines_decoding import (
  11. get_local_outlines_guided_decoding_logits_processor,
  12. get_outlines_guided_decoding_logits_processor)
  13. async def get_guided_decoding_logits_processor(
  14. guided_decoding_backend: str, request: Union[CompletionRequest,
  15. ChatCompletionRequest],
  16. tokenizer) -> Optional[LogitsProcessorFunc]:
  17. request = _adapt_request_for_tool_use(request)
  18. if guided_decoding_backend == 'outlines':
  19. if HAS_TRITON:
  20. return await get_outlines_guided_decoding_logits_processor(
  21. request, tokenizer)
  22. else:
  23. pass
  24. if guided_decoding_backend == 'lm-format-enforcer':
  25. from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import ( # noqa
  26. get_lm_format_enforcer_guided_decoding_logits_processor)
  27. return await get_lm_format_enforcer_guided_decoding_logits_processor(
  28. request, tokenizer)
  29. raise ValueError(
  30. f"Unknown guided decoding backend '{guided_decoding_backend}'. "
  31. "Must be one of 'outlines, 'lm-format-enforcer'")
  32. def get_local_guided_decoding_logits_processor(
  33. guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
  34. tokenizer) -> Optional[LogitsProcessorFunc]:
  35. # request = _adapt_request_for_tool_use(request)
  36. if guided_decoding_backend == 'outlines':
  37. return get_local_outlines_guided_decoding_logits_processor(
  38. guided_options, tokenizer)
  39. if guided_decoding_backend == 'lm-format-enforcer':
  40. from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import ( # noqa
  41. get_local_lm_format_enforcer_guided_decoding_logits_processor)
  42. return get_local_lm_format_enforcer_guided_decoding_logits_processor(
  43. guided_options, tokenizer)
  44. raise ValueError(
  45. f"Unknown guided decoding backend '{guided_decoding_backend}'. "
  46. "Must be one of 'outlines, 'lm-format-enforcer'")
  47. def _adapt_request_for_tool_use(request: Union[CompletionRequest,
  48. ChatCompletionRequest]):
  49. # the legacy completion API does not support tool use
  50. if type(request) is CompletionRequest:
  51. return request
  52. # user has chosen to not use any tool
  53. if request.tool_choice == "none":
  54. return request
  55. # user has chosen to use a named tool
  56. if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
  57. tool_name = request.tool_choice.function.name
  58. tools = {tool.function.name: tool.function for tool in request.tools}
  59. if tool_name not in tools:
  60. raise ValueError(
  61. f"Tool '{tool_name}' has not been passed in `tools`.")
  62. tool = tools[tool_name]
  63. request.guided_json = tool.parameters
  64. return request