serving_tokenization.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from typing import List, Optional, Union
  2. from loguru import logger
  3. from aphrodite.common.config import ModelConfig
  4. from aphrodite.common.utils import random_uuid
  5. from aphrodite.endpoints.chat_utils import (apply_hf_chat_template,
  6. apply_mistral_chat_template,
  7. load_chat_template,
  8. parse_chat_messages_futures)
  9. from aphrodite.endpoints.logger import RequestLogger
  10. # yapf conflicts with isort for this block
  11. # yapf: disable
  12. from aphrodite.endpoints.openai.protocol import (DetokenizeRequest,
  13. DetokenizeResponse,
  14. ErrorResponse,
  15. TokenizeChatRequest,
  16. TokenizeRequest,
  17. TokenizeResponse)
  18. # yapf: enable
  19. from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
  20. OpenAIServing)
  21. from aphrodite.engine.protocol import EngineClient
  22. from aphrodite.transformers_utils.tokenizer import MistralTokenizer
  23. class OpenAIServingTokenization(OpenAIServing):
  24. def __init__(
  25. self,
  26. engine_client: EngineClient,
  27. model_config: ModelConfig,
  28. served_model_names: List[str],
  29. *,
  30. lora_modules: Optional[List[LoRAModulePath]],
  31. request_logger: Optional[RequestLogger],
  32. chat_template: Optional[str],
  33. ):
  34. super().__init__(engine_client=engine_client,
  35. model_config=model_config,
  36. served_model_names=served_model_names,
  37. lora_modules=lora_modules,
  38. prompt_adapters=None,
  39. request_logger=request_logger)
  40. # If this is None we use the tokenizer's default chat template
  41. # the list of commonly-used chat template names for HF named templates
  42. hf_chat_templates: List[str] = ['default', 'tool_use']
  43. self.chat_template = chat_template \
  44. if chat_template in hf_chat_templates \
  45. else load_chat_template(chat_template)
  46. async def create_tokenize(
  47. self,
  48. request: TokenizeRequest,
  49. ) -> Union[TokenizeResponse, ErrorResponse]:
  50. error_check_ret = await self._check_model(request)
  51. if error_check_ret is not None:
  52. return error_check_ret
  53. request_id = f"tokn-{random_uuid()}"
  54. (
  55. lora_request,
  56. prompt_adapter_request,
  57. ) = self._maybe_get_adapters(request)
  58. tokenizer = await self.engine_client.get_tokenizer(lora_request)
  59. prompt: Union[str, List[int]]
  60. if isinstance(request, TokenizeChatRequest):
  61. model_config = self.model_config
  62. conversation, mm_data_future = parse_chat_messages_futures(
  63. request.messages, model_config, tokenizer)
  64. mm_data = await mm_data_future
  65. if mm_data:
  66. logger.warning(
  67. "Multi-modal inputs are ignored during tokenization")
  68. if isinstance(tokenizer, MistralTokenizer):
  69. prompt = apply_mistral_chat_template(
  70. tokenizer,
  71. messages=request.messages,
  72. chat_template=self.chat_template,
  73. add_generation_prompt=request.add_generation_prompt,
  74. )
  75. else:
  76. prompt = apply_hf_chat_template(
  77. tokenizer,
  78. conversation=conversation,
  79. chat_template=self.chat_template,
  80. add_generation_prompt=request.add_generation_prompt,
  81. )
  82. else:
  83. prompt = request.prompt
  84. self._log_inputs(request_id,
  85. prompt,
  86. params=None,
  87. lora_request=lora_request,
  88. prompt_adapter_request=prompt_adapter_request)
  89. # Silently ignore prompt adapter since it does not affect tokenization
  90. prompt_input = self._tokenize_prompt_input(
  91. request,
  92. tokenizer,
  93. prompt,
  94. add_special_tokens=request.add_special_tokens,
  95. )
  96. input_ids = prompt_input["prompt_token_ids"]
  97. return TokenizeResponse(tokens=input_ids,
  98. count=len(input_ids),
  99. max_model_len=self.max_model_len)
  100. async def create_detokenize(
  101. self,
  102. request: DetokenizeRequest,
  103. ) -> Union[DetokenizeResponse, ErrorResponse]:
  104. error_check_ret = await self._check_model(request)
  105. if error_check_ret is not None:
  106. return error_check_ret
  107. request_id = f"tokn-{random_uuid()}"
  108. (
  109. lora_request,
  110. prompt_adapter_request,
  111. ) = self._maybe_get_adapters(request)
  112. tokenizer = await self.engine_client.get_tokenizer(lora_request)
  113. self._log_inputs(request_id,
  114. request.tokens,
  115. params=None,
  116. lora_request=lora_request,
  117. prompt_adapter_request=prompt_adapter_request)
  118. if prompt_adapter_request is not None:
  119. raise NotImplementedError("Prompt adapter is not supported "
  120. "for tokenization")
  121. prompt_input = self._tokenize_prompt_input(
  122. request,
  123. tokenizer,
  124. request.tokens,
  125. )
  126. input_text = prompt_input["prompt"]
  127. return DetokenizeResponse(prompt=input_text)