serving_tokenization.py

from typing import List, Optional, Union

from loguru import logger

from aphrodite.common.config import ModelConfig
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import (apply_hf_chat_template,
                                            apply_mistral_chat_template,
                                            load_chat_template,
                                            parse_chat_messages_futures)
from aphrodite.endpoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from aphrodite.endpoints.openai.protocol import (DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 ErrorResponse,
                                                 TokenizeChatRequest,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
# yapf: enable
from aphrodite.endpoints.openai.serving_engine import (BaseModelPath,
                                                       LoRAModulePath,
                                                       OpenAIServing)
from aphrodite.engine.protocol import EngineClient
from aphrodite.transformers_utils.tokenizer import MistralTokenizer


class OpenAIServingTokenization(OpenAIServing):

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=None,
                         request_logger=request_logger)

        # If chat_template is None, the tokenizer's default chat template is
        # used. Commonly-used names that HF tokenizers resolve as named
        # templates are kept as-is; anything else goes through
        # load_chat_template().
        hf_chat_templates: List[str] = ['default', 'tool_use']
        self.chat_template = chat_template \
            if chat_template in hf_chat_templates \
            else load_chat_template(chat_template)
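
        # For reference, a sketch of how the chat_template argument resolves
        # (the example values below are illustrative, not defaults):
        #   chat_template="tool_use"  -> kept as-is; treated as an HF named
        #                                template when the template is applied
        #   chat_template="<path or literal Jinja string>"
        #                             -> passed through load_chat_template()
        #   chat_template=None        -> the tokenizer's default chat template
        #                                is used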

    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        prompt: Union[str, List[int]]
        if isinstance(request, TokenizeChatRequest):
            model_config = self.model_config

            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, model_config, tokenizer)

            mm_data = await mm_data_future
            if mm_data:
                logger.warning(
                    "Multi-modal inputs are ignored during tokenization")

            if isinstance(tokenizer, MistralTokenizer):
                prompt = apply_mistral_chat_template(
                    tokenizer,
                    messages=request.messages,
                    chat_template=self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                )
            else:
                prompt = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=self.chat_template,
                    add_generation_prompt=request.add_generation_prompt,
                )
        else:
            prompt = request.prompt

        self._log_inputs(request_id,
                         prompt,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization
        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            prompt,
            add_special_tokens=request.add_special_tokens,
        )
        input_ids = prompt_input["prompt_token_ids"]

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)
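
    # For reference, a hedged sketch of the payloads create_tokenize handles.
    # Field names come from the protocol classes imported above; the
    # "/tokenize" route and the concrete values are illustrative assumptions,
    # not defined in this module:
    #
    #   POST /tokenize
    #   {"model": "my-model", "prompt": "Hello world", "add_special_tokens": true}
    #   -> {"tokens": [1, 15043, 3186], "count": 3, "max_model_len": 4096}
    #
    #   Chat-style requests (TokenizeChatRequest) send "messages" and
    #   "add_generation_prompt" instead of "prompt"; the rendered chat
    #   template is then tokenized the same way.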

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        if prompt_adapter_request is not None:
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for tokenization")

        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)
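
# For reference, a hedged sketch of the detokenize wire format. As above, the
# "/detokenize" route and the concrete values are illustrative assumptions:
#
#   POST /detokenize
#   {"model": "my-model", "tokens": [1, 15043, 3186]}
#   -> {"prompt": "Hello world"}
#
# Token ids returned by create_tokenize should round-trip through
# create_detokenize back to the original text (special tokens aside).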