serving_tokenization.py

from typing import List, Optional, Union

from loguru import logger

from aphrodite.common.config import ModelConfig
from aphrodite.common.utils import random_uuid
# yapf conflicts with isort
# yapf: disable
from aphrodite.endpoints.chat_utils import (apply_chat_template,
                                            load_chat_template,
                                            parse_chat_messages)
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.protocol import (DetokenizeRequest,
                                                 DetokenizeResponse,
                                                 ErrorResponse,
                                                 TokenizeChatRequest,
                                                 TokenizeRequest,
                                                 TokenizeResponse)
# yapf: enable
from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                        OpenAIServing)
from aphrodite.engine.protocol import AsyncEngineClient

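# OpenAIServingTokenization backs the tokenization and detokenization
# endpoints of the OpenAI-compatible API server, reusing OpenAIServing for
# model checks, LoRA adapter resolution, and request logging.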
class OpenAIServingTokenization(OpenAIServing):

    def __init__(
        self,
        async_engine_client: AsyncEngineClient,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        super().__init__(async_engine_client=async_engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         prompt_adapters=None,
                         request_logger=request_logger)

        # If this is None we use the tokenizer's default chat template
        self.chat_template = load_chat_template(chat_template)

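    # create_tokenize validates the requested model, resolves any LoRA
    # adapter, fetches the (possibly adapter-specific) tokenizer from the
    # engine, renders chat requests through the chat template (multi-modal
    # inputs are ignored), and returns the resulting token ids.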
    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

        if isinstance(request, TokenizeChatRequest):
            model_config = self.model_config

            conversation, mm_futures = parse_chat_messages(
                request.messages, model_config, tokenizer)

            if mm_futures:
                logger.warning(
                    "Multi-modal inputs are ignored during tokenization")

            prompt = apply_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
            )
        else:
            prompt = request.prompt

        self._log_inputs(request_id,
                         prompt,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization
        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            prompt,
            add_special_tokens=request.add_special_tokens,
        )

        input_ids = prompt_input["prompt_token_ids"]

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

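    # create_detokenize mirrors create_tokenize: it logs the incoming token
    # ids, rejects prompt adapters, and round-trips the ids back to text via
    # _tokenize_prompt_input.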
    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        if prompt_adapter_request is not None:
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for tokenization")

        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)
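

# Illustrative usage (not part of the original module): in the
# OpenAI-compatible API server these handlers are typically wired to HTTP
# routes, and requests of the following shape exercise them. Field names are
# taken from the protocol classes imported above; the exact routes and any
# extra fields depend on the server wiring.
#
#   POST /tokenize    {"model": "my-model", "prompt": "Hello!",
#                      "add_special_tokens": true}
#   POST /tokenize    {"model": "my-model", "add_generation_prompt": true,
#                      "messages": [{"role": "user", "content": "Hello!"}]}
#   POST /detokenize  {"model": "my-model", "tokens": [1, 2, 3]}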