# serving_tokenization.py

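"""Tokenize/detokenize request handlers for Aphrodite's OpenAI-compatible API server."""
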
from typing import List, Optional, Union

from loguru import logger

from aphrodite.common.config import ModelConfig
from aphrodite.common.utils import random_uuid
# yapf conflicts with isort
# yapf: disable
from aphrodite.endpoints.chat_utils import (apply_chat_template,
                                            load_chat_template,
                                            parse_chat_messages_futures)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.protocol import (DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 ErrorResponse,
                                                 TokenizeChatRequest,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
# yapf: enable
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                        OpenAIServing)
from aphrodite.engine.protocol import AsyncEngineClient


class OpenAIServingTokenization(OpenAIServing):
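    """Handles tokenization and detokenization requests for the
    OpenAI-compatible API.

    `create_tokenize` turns a plain prompt or a chat conversation into token
    IDs; `create_detokenize` converts token IDs back into text.
    """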

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=None,
                         request_logger=request_logger)

        # If this is None, the tokenizer's default chat template is used.
        # Names in this list refer to templates the HF tokenizer resolves by
        # name, so they are passed through unchanged instead of being loaded
        # from a file or literal string.
        hf_chat_templates: List[str] = ['default', 'tool_use']
        self.chat_template = chat_template \
            if chat_template in hf_chat_templates \
            else load_chat_template(chat_template)

    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
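        """Tokenize a plain prompt or a chat conversation.

        Chat requests (`TokenizeChatRequest`) are first rendered with the
        chat template; plain requests use `request.prompt` directly. The
        response carries the token IDs, their count, and the model's maximum
        context length.
        """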
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        # The tokenizer can differ per LoRA adapter, so ask the engine for
        # the one matching this request.
        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

        if isinstance(request, TokenizeChatRequest):
            model_config = self.model_config

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            mm_data = await mm_data_future
            if mm_data:
                logger.warning(
                    "Multi-modal inputs are ignored during tokenization")

            prompt = apply_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
            )
        else:
            prompt = request.prompt

        self._log_inputs(request_id,
                         prompt,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization
        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            prompt,
            add_special_tokens=request.add_special_tokens,
        )

        input_ids = prompt_input["prompt_token_ids"]

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
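        """Convert a list of token IDs (`request.tokens`) back into text."""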
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        if prompt_adapter_request is not None:
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for tokenization")

        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            request.tokens,
        )

        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)
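
# A minimal usage sketch (an assumption for illustration, not defined in this
# module): in the usual Aphrodite OpenAI-compatible server wiring, an instance
# of OpenAIServingTokenization backs the tokenize/detokenize HTTP routes, e.g.
#
#   curl http://localhost:2242/tokenize \
#        -H "Content-Type: application/json" \
#        -d '{"model": "<served-model-name>", "prompt": "Hello, world!"}'
#
# The endpoint path, default port, and exact request schema are configured
# elsewhere in the codebase.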