View source code

Refactor prompt processing (#605)

* wip

* finish up the refactor
AlpinDale 7 months ago
parent
commit
1ab2dad198

+ 13 - 24
aphrodite/endpoints/llm.py

@@ -10,8 +10,7 @@ from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.utils import Counter, deprecate_kwargs
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
-                              TextTokensPrompt, TokensPrompt,
+from aphrodite.inputs import (PromptInputs, TextPrompt, TokensPrompt,
                               parse_and_batch_prompt)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
@@ -233,7 +232,7 @@ class LLM:
     @overload
     def generate(
         self,
-        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
         /,  # We may enable `inputs` keyword after removing the old API
         *,
         sampling_params: Optional[Union[SamplingParams,
@@ -250,7 +249,7 @@ class LLM:
                       "instead.")
     def generate(
         self,
-        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                        Optional[Union[str, List[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
@@ -297,9 +296,7 @@ class LLM:
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            inputs = cast(
-                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
-                prompts)
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
 
         if sampling_params is None:
             # Use default sampling params.
@@ -379,7 +376,7 @@ class LLM:
     @overload
     def encode(
         self,
-        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
         /,  # We may enable `inputs` keyword after removing the old API
         *,
         pooling_params: Optional[Union[PoolingParams,
@@ -396,7 +393,7 @@ class LLM:
                       "instead.")
     def encode(
         self,
-        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                        Optional[Union[str, List[str]]]] = None,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -414,7 +411,7 @@ class LLM:
         Args:
             inputs: The inputs to the LLM. You may pass a sequence of inputs for
                 batch inference. See
-                :class:`~aphrodite.inputs.PromptStrictInputs`
+                :class:`~aphrodite.inputs.PromptInputs`
                 for more details about the format of each input.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
@@ -443,9 +440,7 @@ class LLM:
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            inputs = cast(
-                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
-                prompts)
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
 
         if pooling_params is None:
             # Use default pooling params.
@@ -493,17 +488,11 @@ class LLM:
         inputs: List[PromptInputs] = []
         for i in range(num_requests):
             if prompts is not None:
-                if prompt_token_ids is not None:
-                    item = TextTokensPrompt(
-                        prompt=prompts[i],
-                        prompt_token_ids=prompt_token_ids[i])
-                else:
-                    item = TextPrompt(prompt=prompts[i])
+                item = TextPrompt(prompt=prompts[i])
+            elif prompt_token_ids is not None:
+                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
             else:
-                if prompt_token_ids is not None:
-                    item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
-                else:
-                    raise AssertionError
+                raise AssertionError
 
             inputs.append(item)
 
@@ -511,7 +500,7 @@ class LLM:
 
     def _validate_and_add_requests(
         self,
-        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
         params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],

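Note: with `PromptStrictInputs` and `TextTokensPrompt` removed, `LLM.generate`/`LLM.encode` take the unified `PromptInputs` type, where text and pre-tokenized prompts are separate TypedDicts. The sketch below is not part of the diff; import paths mirror the ones shown above, while the model name and token IDs are placeholders.

```python
# Hedged sketch: calling the refactored LLM API with the unified PromptInputs
# type. Import paths follow this diff; model name and token IDs are placeholders.
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.endpoints.llm import LLM
from aphrodite.inputs import TextPrompt, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(max_tokens=16)

# Text prompts and pre-tokenized prompts are now separate TypedDicts; the
# removed TextTokensPrompt (text + token IDs together) is no longer accepted.
outputs = llm.generate(
    [
        TextPrompt(prompt="Hello, my name is"),
        TokensPrompt(prompt_token_ids=[1, 2, 3]),  # illustrative token IDs
    ],
    sampling_params=params,
)
for output in outputs:
    print(output.outputs[0].text)
```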
+ 40 - 0
aphrodite/endpoints/logger.py

@@ -0,0 +1,40 @@
+from typing import List, Optional, Union
+
+from loguru import logger
+
+from aphrodite.common.pooling_params import PoolingParams
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.lora.request import LoRARequest
+from aphrodite.prompt_adapter.request import PromptAdapterRequest
+
+
+class RequestLogger:
+
+    def __init__(self, *, max_log_len: Optional[int]) -> None:
+        super().__init__()
+
+        self.max_log_len = max_log_len
+
+    def log_inputs(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]],
+        params: Optional[Union[SamplingParams, PoolingParams]],
+        lora_request: Optional[LoRARequest],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+    ) -> None:
+        max_log_len = self.max_log_len
+        if max_log_len is not None:
+            if prompt is not None:
+                prompt = prompt[:max_log_len]
+
+            if prompt_token_ids is not None:
+                prompt_token_ids = prompt_token_ids[:max_log_len]
+
+        logger.info(f"Received request {request_id}: "
+                    f"prompt: {prompt}, "
+                    f"params: {params}, "
+                    f"prompt_token_ids: {prompt_token_ids}, "
+                    f"lora_request: {lora_request}, "
+                    f"prompt_adapter_request: {prompt_adapter_request}.")

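Since `RequestLogger` is added in full above, a standalone usage sketch is straightforward; the request ID, prompt, and truncation value below are arbitrary examples.

```python
# Sketch of the RequestLogger added above. With max_log_len=8, only the first
# 8 characters of the prompt (or the first 8 token IDs) appear in the log line.
from aphrodite.endpoints.logger import RequestLogger

request_logger = RequestLogger(max_log_len=8)
request_logger.log_inputs(
    request_id="cmpl-example",
    prompt="This prompt is longer than eight characters",
    prompt_token_ids=None,
    params=None,
    lora_request=None,
    prompt_adapter_request=None,
)
```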
+ 36 - 11
aphrodite/endpoints/openai/api_server.py

@@ -23,6 +23,7 @@ from starlette.routing import Mount
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
 from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
+from aphrodite.endpoints.logger import RequestLogger
 from aphrodite.endpoints.openai.args import make_arg_parser
 from aphrodite.endpoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse, CompletionRequest,
@@ -515,24 +516,48 @@ def run_server(args, llm_engine=None):
         # When using single Aphrodite without engine_use_ray
         model_config = asyncio.run(engine.get_model_config())
 
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
     global openai_serving_chat
     global openai_serving_completion
     global openai_serving_embedding
     global openai_serving_tokenization
 
-    openai_serving_chat = OpenAIServingChat(engine, model_config,
-                                            served_model_names,
-                                            args.response_role,
-                                            args.lora_modules,
-                                            args.chat_template)
+    openai_serving_chat = OpenAIServingChat(
+        engine,
+        model_config,
+        served_model_names,
+        args.response_role,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     openai_serving_completion = OpenAIServingCompletion(
-        engine, model_config, served_model_names, args.lora_modules,
-        args.prompt_adapters)
-    openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
-                                                      served_model_names)
+        engine,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+    )
+    openai_serving_embedding = OpenAIServingEmbedding(
+        engine,
+        model_config,
+        served_model_names,
+        request_logger=request_logger,
+    )
     openai_serving_tokenization = OpenAIServingTokenization(
-        engine, model_config, served_model_names, args.lora_modules,
-        args.chat_template)
+        engine,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     app.root_path = args.root_path
 
     tokenizer = get_tokenizer(

+ 6 - 0
aphrodite/endpoints/openai/args.py

@@ -134,6 +134,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "--launch-kobold-api",
         action="store_true",
         help="Launch the Kobold API server alongside the OpenAI server")
+    parser.add_argument("--max-log-len",
+                        type=int,
+                        default=0,
+                        help="Max number of prompt characters or prompt "
+                        "ID numbers being printed in log."
+                        "\n\nDefault: 0")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser

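The new `--max-log-len` option feeds the `RequestLogger` exactly as `run_server` does above. A hedged sketch of that wiring, using the parser factory from this file: the flag value is illustrative, and `--disable-log-requests` is assumed to come from `AsyncEngineArgs.add_cli_args` as used elsewhere in this diff.

```python
# Hedged sketch: parse the new --max-log-len flag and build the RequestLogger
# the same way run_server does. Assumes all other options fall back to their
# argparse defaults.
from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.endpoints.logger import RequestLogger
from aphrodite.endpoints.openai.args import make_arg_parser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args(["--max-log-len", "64"])

request_logger = (None if args.disable_log_requests else
                  RequestLogger(max_log_len=args.max_log_len))
```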
+ 19 - 15
aphrodite/endpoints/openai/protocol.py

@@ -381,15 +381,11 @@ class CompletionRequest(OpenAIBaseModel):
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    include_stop_str_in_output: Optional[bool] = False
+    add_special_tokens: Optional[bool] = False
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
-    include_stop_str_in_output: Optional[bool] = Field(
-        default=False,
-        description=(
-            "Whether to include the stop string in the output. "
-            "This is only applied when the stop or stop_token_ids is set."),
-    )
     response_format: Optional[ResponseFormat] = Field(
         default=None,
         description=
@@ -521,7 +517,7 @@ class CompletionRequest(OpenAIBaseModel):
         return data
 
 
-class EmbeddingRequest(BaseModel):
+class EmbeddingRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
     model: str
@@ -593,13 +589,13 @@ class CompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
-class EmbeddingResponseData(BaseModel):
+class EmbeddingResponseData(OpenAIBaseModel):
     index: int
     object: str = "embedding"
     embedding: Union[List[float], str]
 
 
-class EmbeddingResponse(BaseModel):
+class EmbeddingResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "list"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -698,8 +694,8 @@ class BatchRequestInput(OpenAIBaseModel):
     # /v1/chat/completions is supported.
     url: str
 
-    # The parameteters of the request.
-    body: Union[ChatCompletionRequest, ]
+    # The parameters of the request.
+    body: ChatCompletionRequest
 
 
 class BatchResponseData(OpenAIBaseModel):
@@ -731,12 +727,20 @@ class BatchRequestOutput(OpenAIBaseModel):
     error: Optional[Any]
 
 
-class TokenizeRequest(OpenAIBaseModel):
-    model: Optional[str]
+class TokenizeCompletionRequest(OpenAIBaseModel):
+    model: str
+    prompt: str
+    add_special_tokens: bool = Field(default=True)
+
+
+class TokenizeChatRequest(OpenAIBaseModel):
+    model: str
+    messages: List[ChatCompletionMessageParam]
     add_generation_prompt: bool = Field(default=True)
     add_special_tokens: bool = Field(default=False)
-    prompt: Optional[str] = Field(default=None)
-    messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None)
+
+
+TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
 
 
 class TokenizeResponse(OpenAIBaseModel):

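`TokenizeRequest` is now a union of two concrete models rather than a single model with mutually exclusive optional fields. A hedged sketch of constructing each: field values are illustrative, and the message dict is assumed to validate against `ChatCompletionMessageParam` in the usual OpenAI chat format.

```python
# Sketch of the split tokenization request models defined above
# (field values are illustrative).
from aphrodite.endpoints.openai.protocol import (TokenizeChatRequest,
                                                 TokenizeCompletionRequest)

# Plain-text tokenization: mirrors /v1/completions-style input.
completion_request = TokenizeCompletionRequest(
    model="my-model",
    prompt="Hello, world!",
)

# Chat tokenization: the chat template is applied before tokenizing.
chat_request = TokenizeChatRequest(
    model="my-model",
    messages=[{"role": "user", "content": "Hello, world!"}],
    add_generation_prompt=True,
)
```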
+ 16 - 0
aphrodite/endpoints/openai/run_batch.py

@@ -6,6 +6,7 @@ import aiohttp
 from loguru import logger
 
 from aphrodite.common.utils import FlexibleArgumentParser, random_uuid
+from aphrodite.endpoints.logger import RequestLogger
 from aphrodite.endpoints.openai.protocol import (BatchRequestInput,
                                                  BatchRequestOutput,
                                                  BatchResponseData,
@@ -42,6 +43,12 @@ def parse_args():
                         default="assistant",
                         help="The role name to return if "
                         "`request.add_generation_prompt=true`.")
+    parser.add_argument("--max-log-len",
+                        type=int,
+                        default=0,
+                        help="Max number of prompt characters or prompt "
+                        "ID numbers being printed in log."
+                        "\n\nDefault: 0")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     return parser.parse_args()
@@ -111,11 +118,20 @@ async def main(args):
     # When using single Aphrodite without engine_use_ray
     model_config = await engine.get_model_config()
 
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
         served_model_names,
         args.response_role,
+        lora_modules=None,
+        prompt_adapters=None,
+        request_logger=request_logger,
+        chat_template=None,
     )
 
     # Submit all requests in the file to the engine "concurrently".

+ 62 - 38
aphrodite/endpoints/openai/serving_chat.py

@@ -15,6 +15,7 @@ from aphrodite.common.utils import random_uuid
 from aphrodite.endpoints.chat_utils import (ConversationMessage,
                                             load_chat_template,
                                             parse_chat_message_content)
+from aphrodite.endpoints.logger import RequestLogger
 from aphrodite.endpoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
     ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
@@ -23,7 +24,8 @@ from aphrodite.endpoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     FunctionCall, ToolCall, UsageInfo)
 from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
-                                                       OpenAIServing)
+                                                       OpenAIServing,
+                                                       PromptAdapterPath)
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.inputs import PromptInputs
 from aphrodite.modeling.guided_decoding import \
@@ -33,17 +35,24 @@ from aphrodite.multimodal import MultiModalDataDict
 
 class OpenAIServingChat(OpenAIServing):
 
-    def __init__(self,
-                 engine: AsyncAphrodite,
-                 model_config: ModelConfig,
-                 served_model_names: List[str],
-                 response_role: str,
-                 lora_modules: Optional[List[LoRAModulePath]] = None,
-                 chat_template: Optional[str] = None):
+    def __init__(
+        self,
+        engine: AsyncAphrodite,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        response_role: str,
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters,
+                         request_logger=request_logger)
 
         self.response_role = response_role
         # If this is None we use the tokenizer's default chat template
@@ -69,14 +78,19 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
 
         try:
-            _, lora_request = self._maybe_get_adapter(request)
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            model_config = self.model_config
             tokenizer = await self.engine.get_tokenizer(lora_request)
             conversation: List[ConversationMessage] = []
             mm_futures: List[Awaitable[MultiModalDataDict]] = []
 
             for msg in request.messages:
                 chat_parsed_result = parse_chat_message_content(
-                    msg, self.model_config, tokenizer)
+                    msg, model_config, tokenizer)
 
                 conversation.extend(chat_parsed_result.messages)
                 mm_futures.extend(chat_parsed_result.mm_futures)
@@ -110,14 +124,8 @@ class OpenAIServingChat(OpenAIServing):
             logger.error("Error in loading multi-modal data: {e}")
             return self.create_error_response(str(e))
 
-        request_id = f"cmpl-{random_uuid()}"
+        request_id = f"chat-{random_uuid()}"
         try:
-            # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
-                request,
-                tokenizer,
-                prompt=prompt,
-                add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
@@ -131,22 +139,38 @@ class OpenAIServingChat(OpenAIServing):
                     sampling_params.logits_processors = []
                 sampling_params.logits_processors.append(
                     guided_decode_logits_processor)
+
+            prompt_inputs = self._tokenize_prompt_input(
+                request,
+                tokenizer,
+                prompt,
+                truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+
+            self._log_inputs(request_id,
+                             prompt_inputs,
+                             params=sampling_params,
+                             lora_request=lora_request,
+                             prompt_adapter_request=prompt_adapter_request)
+
+            engine_inputs: PromptInputs = {
+                "prompt_token_ids": prompt_inputs["prompt_token_ids"],
+            }
+            if mm_data is not None:
+                engine_inputs["multi_modal_data"] = mm_data
+
+            result_generator = self.engine.generate(
+                engine_inputs,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
         except ValueError as e:
+            # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
 
-        inputs: PromptInputs = {
-            "prompt": prompt_text,
-            "prompt_token_ids": prompt_ids,
-        }
-        if mm_data:
-            inputs["multi_modal_data"] = mm_data
-
-        result_generator = self.engine.generate(
-            inputs,
-            sampling_params,
-            request_id,
-            lora_request,
-        )
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
@@ -180,10 +204,10 @@ class OpenAIServingChat(OpenAIServing):
         first_iteration = True
 
         # Send response for each token for each request.n (index)
-        assert request.n is not None
-        previous_texts = [""] * request.n
-        previous_num_tokens = [0] * request.n
-        finish_reason_sent = [False] * request.n
+        num_choices = 1 if request.n is None else request.n
+        previous_texts = [""] * num_choices
+        previous_num_tokens = [0] * num_choices
+        finish_reason_sent = [False] * num_choices
         try:
             async for res in result_generator:
                 # We need to do it here, because if there are exceptions in
@@ -193,7 +217,7 @@ class OpenAIServingChat(OpenAIServing):
                     # Send first response for each request.n (index) with
                     # the role
                     role = self.get_chat_request_role(request)
-                    for i in range(request.n):
+                    for i in range(num_choices):
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
                             delta=DeltaMessage(role=role),
@@ -221,19 +245,19 @@ class OpenAIServingChat(OpenAIServing):
                             last_msg_content = conversation[-1]["content"]
 
                         if last_msg_content:
-                            for i in range(request.n):
+                            for i in range(num_choices):
                                 choice_data = (
                                     ChatCompletionResponseStreamChoice(
                                         index=i,
                                         delta=DeltaMessage(
                                             content=last_msg_content),
+                                        logprobs=None,
                                         finish_reason=None))
                                 chunk = ChatCompletionStreamResponse(
                                     id=request_id,
                                     object=chunk_object_type,
                                     created=created_time,
                                     choices=[choice_data],
-                                    logprobs=None,
                                     model=model_name)
                                 if (request.stream_options and
                                         request.stream_options.include_usage):

+ 59 - 54
aphrodite/endpoints/openai/serving_completions.py

@@ -2,7 +2,7 @@ import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                     Optional)
 from typing import Sequence as GenericSequence
-from typing import Tuple
+from typing import Tuple, cast
 
 from fastapi import Request
 from transformers import PreTrainedTokenizer
@@ -11,6 +11,7 @@ from aphrodite.common.config import ModelConfig
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sequence import Logprob
 from aphrodite.common.utils import merge_async_iterators, random_uuid
+from aphrodite.endpoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
 from aphrodite.endpoints.openai.protocol import (
@@ -31,40 +32,24 @@ TypeCreateLogProbsFn = Callable[
     [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
 
 
-def parse_prompt_format(prompt) -> Tuple[bool, list]:
-    # get the prompt, openai supports the following
-    # "a string, array of strings, array of tokens, or array of token arrays."
-    prompt_is_tokens = False
-    prompts = [prompt]  # case 1: a string
-    if isinstance(prompt, list):
-        if len(prompt) == 0:
-            raise ValueError("please provide at least one prompt")
-        elif isinstance(prompt[0], str):
-            prompt_is_tokens = False
-            prompts = prompt  # case 2: array of strings
-        elif isinstance(prompt[0], int):
-            prompt_is_tokens = True
-            prompts = [prompt]  # case 3: array of tokens
-        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
-            prompt_is_tokens = True
-            prompts = prompt  # case 4: array of token arrays
-        else:
-            raise ValueError("prompt must be a string, array of strings, "
-                             "array of tokens, or array of token arrays")
-    return prompt_is_tokens, prompts
-
-
 class OpenAIServingCompletion(OpenAIServing):
 
-    def __init__(self, engine: AsyncAphrodite, model_config: ModelConfig,
-                 served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]],
-                 prompt_adapters: Optional[List[PromptAdapterPath]]):
+    def __init__(
+        self,
+        engine: AsyncAphrodite,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
-                         prompt_adapters=prompt_adapters)
+                         prompt_adapters=prompt_adapters,
+                         request_logger=request_logger)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -93,12 +78,11 @@ class OpenAIServingCompletion(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: List[AsyncIterator[RequestOutput]] = []
         try:
-            adapter_type, adapter_request = self._maybe_get_adapter(request)
-            lora_request, prompt_adapter_request = None, None
-            if adapter_type == 'LoRA':
-                lora_request, prompt_adapter_request = adapter_request, None
-            elif adapter_type == 'PromptAdapter':
-                lora_request, prompt_adapter_request = None, adapter_request
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
             sampling_params = request.to_sampling_params()
@@ -114,25 +98,30 @@ class OpenAIServingCompletion(OpenAIServing):
                     sampling_params.logits_processors = []
                 sampling_params.logits_processors.append(
                     guided_decode_logit_processor)
-            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
-            for i, prompt in enumerate(prompts):
-                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
-                prompt_formats = await self._validate_prompt_and_tokenize(
+            prompts = list(
+                self._tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
+                    request.prompt,
                     truncate_prompt_tokens=sampling_params.
                     truncate_prompt_tokens,
-                    **{prompt_arg: prompt})
-                prompt_ids, prompt_text = prompt_formats
+                    add_special_tokens=request.add_special_tokens,
+                ))
+
+            for i, prompt_inputs in enumerate(prompts):
+                request_id_item = f"{request_id}-{i}"
+
+                self._log_inputs(request_id_item,
+                                 prompt_inputs,
+                                 params=sampling_params,
+                                 lora_request=lora_request,
+                                 prompt_adapter_request=prompt_adapter_request)
 
                 generator = self.engine.generate(
-                    {
-                        "prompt": prompt_text,
-                        "prompt_token_ids": prompt_ids
-                    },
+                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                     sampling_params,
-                    f"{request_id}-{i}",
+                    request_id_item,
                     lora_request=lora_request,
                     prompt_adapter_request=prompt_adapter_request,
                 )
@@ -172,9 +161,26 @@ class OpenAIServingCompletion(OpenAIServing):
                     await self.engine.abort(f"{request_id}-{i}")
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
+
+            for i, final_res in enumerate(final_res_batch):
+                assert final_res is not None
+
+                # The output should contain the input text
+                # We did not pass it into vLLM engine to avoid being redundant
+                # with the inputs token IDs
+                if final_res.prompt is None:
+                    final_res.prompt = prompts[i]["prompt"]
+
+            final_res_batch_checked = cast(List[RequestOutput],
+                                           final_res_batch)
             response = self.request_output_to_completion_response(
-                final_res_batch, request, request_id, created_time, model_name,
-                tokenizer)
+                final_res_batch_checked,
+                request,
+                request_id,
+                created_time,
+                model_name,
+                tokenizer,
+            )
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error
             return self.create_error_response(str(e))
@@ -203,10 +209,10 @@ class OpenAIServingCompletion(OpenAIServing):
         num_prompts: int,
         tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
-        assert request.n is not None
-        previous_texts = [""] * request.n * num_prompts
-        previous_num_tokens = [0] * request.n * num_prompts
-        has_echoed = [False] * request.n * num_prompts
+        num_choices = 1 if request.n is None else request.n
+        previous_texts = [""] * num_choices * num_prompts
+        previous_num_tokens = [0] * num_choices * num_prompts
+        has_echoed = [False] * num_choices * num_prompts
 
         try:
             async for prompt_idx, res in result_generator:
@@ -217,7 +223,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     raise StopAsyncIteration()
 
                 for output in res.outputs:
-                    i = output.index + prompt_idx * request.n
+                    i = output.index + prompt_idx * num_choices
                     # TODO: optimize the performance by avoiding full
                     # text O(n^2) sending.
 
@@ -327,7 +333,6 @@ class OpenAIServingCompletion(OpenAIServing):
         num_prompt_tokens = 0
         num_generated_tokens = 0
         for final_res in final_res_batch:
-            assert final_res is not None
             prompt_token_ids = final_res.prompt_token_ids
             prompt_logprobs = final_res.prompt_logprobs
             prompt_text = final_res.prompt

+ 52 - 23
aphrodite/endpoints/openai/serving_embedding.py

@@ -1,6 +1,6 @@
 import base64
 import time
-from typing import AsyncIterator, List, Optional, Tuple
+from typing import AsyncIterator, List, Optional, Tuple, cast
 
 import numpy as np
 from fastapi import Request
@@ -9,11 +9,11 @@ from loguru import logger
 from aphrodite.common.config import ModelConfig
 from aphrodite.common.outputs import EmbeddingRequestOutput
 from aphrodite.common.utils import merge_async_iterators, random_uuid
+from aphrodite.endpoints.logger import RequestLogger
 from aphrodite.endpoints.openai.protocol import (EmbeddingRequest,
                                                  EmbeddingResponse,
                                                  EmbeddingResponseData,
                                                  UsageInfo)
-from aphrodite.endpoints.openai.serving_completions import parse_prompt_format
 from aphrodite.endpoints.openai.serving_engine import OpenAIServing
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
 
@@ -27,11 +27,11 @@ def request_output_to_embedding_response(
     data: List[EmbeddingResponseData] = []
     num_prompt_tokens = 0
     for idx, final_res in enumerate(final_res_batch):
-        assert final_res is not None
         prompt_token_ids = final_res.prompt_token_ids
         embedding = final_res.outputs.embedding
         if encoding_format == "base64":
-            embedding = base64.b64encode(np.array(embedding))
+            embedding_bytes = np.array(embedding).tobytes()
+            embedding = base64.b64encode(embedding_bytes).decode("utf-8")
         embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
         data.append(embedding_data)
 
@@ -53,12 +53,20 @@ def request_output_to_embedding_response(
 
 class OpenAIServingEmbedding(OpenAIServing):
 
-    def __init__(self, engine: AsyncAphrodite, model_config: ModelConfig,
-                 served_model_names: List[str]):
+    def __init__(
+        self,
+        engine: AsyncAphrodite,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        request_logger: Optional[RequestLogger],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=None)
+                         lora_modules=None,
+                         prompt_adapters=None,
+                         request_logger=request_logger)
         self._check_embedding_mode(model_config.embedding_mode)
 
     async def create_embedding(self, request: EmbeddingRequest,
@@ -79,30 +87,46 @@ class OpenAIServingEmbedding(OpenAIServing):
                 "dimensions is currently not supported")
 
         model_name = request.model
-        request_id = f"cmpl-{random_uuid()}"
+        request_id = f"embd-{random_uuid()}"
         created_time = int(time.monotonic())
 
         # Schedule the request and get the result generator.
-        generators = []
+        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
         try:
-            prompt_is_tokens, prompts = parse_prompt_format(request.input)
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            tokenizer = await self.engine.get_tokenizer(lora_request)
             pooling_params = request.to_pooling_params()
 
-            tokenizer = await self.engine.get_tokenizer()
-            for i, prompt in enumerate(prompts):
-                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
-                prompt_formats = await self._validate_prompt_and_tokenize(
-                    request, tokenizer, **{prompt_arg: prompt})
+            prompts = list(
+                self._tokenize_prompt_input_or_inputs(
+                    request,
+                    tokenizer,
+                    request.input,
+                ))
+
+            for i, prompt_inputs in enumerate(prompts):
+                request_id_item = f"{request_id}-{i}"
+
+                self._log_inputs(request_id_item,
+                                 prompt_inputs,
+                                 params=pooling_params,
+                                 lora_request=lora_request,
+                                 prompt_adapter_request=prompt_adapter_request)
 
-                prompt_ids, prompt_text = prompt_formats
+                if prompt_adapter_request is not None:
+                    raise NotImplementedError(
+                        "Prompt adapter is not supported "
+                        "for embedding models")
 
                 generator = self.engine.encode(
-                    {
-                        "prompt": prompt_text,
-                        "prompt_token_ids": prompt_ids
-                    },
+                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                     pooling_params,
-                    f"{request_id}-{i}",
+                    request_id_item,
+                    lora_request=lora_request,
                 )
 
                 generators.append(generator)
@@ -121,11 +145,16 @@ class OpenAIServingEmbedding(OpenAIServing):
                 if await raw_request.is_disconnected():
                     # Abort the request if the client disconnects.
                     await self.engine.abort(f"{request_id}-{i}")
-                    # TODO: Use an aphrodite-specific Validation Error
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
+
+            for final_res in final_res_batch:
+                assert final_res is not None
+
+            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+                                           final_res_batch)
             response = request_output_to_embedding_response(
-                final_res_batch, request_id, created_time, model_name,
+                final_res_batch_checked, request_id, created_time, model_name,
                 encoding_format)
         except ValueError as e:
             # TODO: Use an aphrodite-specific Validation Error

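The base64 branch above now encodes the raw float bytes instead of base64-encoding a numpy array object directly. A hedged sketch of the encoding plus the matching client-side decode, assuming the `float64` dtype that `np.array` infers from a list of Python floats:

```python
# Sketch of the corrected base64 embedding encoding and the matching
# client-side decode. np.array over a list of Python floats yields float64.
import base64

import numpy as np

embedding = [0.1, -0.25, 0.5]  # illustrative embedding values

embedding_bytes = np.array(embedding).tobytes()
encoded = base64.b64encode(embedding_bytes).decode("utf-8")

decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float64)
assert np.allclose(decoded, embedding)
```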
+ 213 - 65
aphrodite/endpoints/openai/serving_engine.py

@@ -2,19 +2,31 @@ import json
 import pathlib
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
 
 from pydantic import Field
-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Annotated
 
 from aphrodite.common.config import ModelConfig
+from aphrodite.common.pooling_params import PoolingParams
+from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import Logprob
-from aphrodite.endpoints.openai.protocol import (
-    ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
-    EmbeddingRequest, ErrorResponse, ModelCard, ModelList, ModelPermission,
-    TokenizeRequest)
+from aphrodite.endpoints.logger import RequestLogger
+# yapf conflicts with isort here
+# yapf: disable
+from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
+                                                 CompletionRequest,
+                                                 DetokenizeRequest,
+                                                 EmbeddingRequest,
+                                                 ErrorResponse, ModelCard,
+                                                 ModelList, ModelPermission,
+                                                 TokenizeChatRequest,
+                                                 TokenizeCompletionRequest,
+                                                 TokenizeRequest)
+# yapf: enable
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.inputs import parse_and_batch_prompt
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 
@@ -31,6 +43,17 @@ class LoRAModulePath:
     local_path: str
 
 
+AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
+                   EmbeddingRequest, TokenizeRequest]
+
+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+
+
+class TextTokensPrompt(TypedDict):
+    prompt: str
+    prompt_token_ids: List[int]
+
+
 class OpenAIServing:
 
     def __init__(
@@ -38,8 +61,10 @@ class OpenAIServing:
         engine: AsyncAphrodite,
         model_config: ModelConfig,
         served_model_names: List[str],
+        *,
         lora_modules: Optional[List[LoRAModulePath]],
-        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
     ):
         super().__init__()
 
@@ -73,6 +98,8 @@ class OpenAIServing:
                         prompt_adapter_local_path=prompt_adapter.local_path,
                         prompt_adapter_num_virtual_tokens=num_virtual_tokens))
 
+        self.request_logger = request_logger
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
@@ -121,9 +148,8 @@ class OpenAIServing:
         return json_str
 
     async def _check_model(
-        self, request: Union[ChatCompletionRequest, CompletionRequest,
-                             DetokenizeRequest, EmbeddingRequest,
-                             TokenizeRequest]
+        self,
+        request: AnyRequest,
     ) -> Optional[ErrorResponse]:
         # only check these if it's not a Tokenizer/Detokenize Request
         if not isinstance(request, (TokenizeRequest, DetokenizeRequest)):
@@ -143,64 +169,65 @@ class OpenAIServing:
                 err_type="NotFoundError",
                 status_code=HTTPStatus.NOT_FOUND)
 
-    def _maybe_get_adapter(
-        self, request: Union[CompletionRequest, ChatCompletionRequest,
-                             EmbeddingRequest, TokenizeRequest,
-                             DetokenizeRequest]
-    ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
-                                             PromptAdapterRequest]]]:
+    def _maybe_get_adapters(
+        self, request: AnyRequest
+    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
+            None, PromptAdapterRequest]]:
         if request.model in self.served_model_names:
             return None, None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
-                return 'LoRA', lora
+                return lora, None
         for prompt_adapter in self.prompt_adapter_requests:
             if request.model == prompt_adapter.prompt_adapter_name:
-                return 'PromptAdapter', prompt_adapter
+                return None, prompt_adapter
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")
 
-    async def _validate_prompt_and_tokenize(
-            self,
-            request: Union[ChatCompletionRequest, CompletionRequest,
-                           DetokenizeRequest, EmbeddingRequest,
-                           TokenizeRequest],
-            tokenizer: "PreTrainedTokenizer",
-            prompt: Optional[str] = None,
-            prompt_ids: Optional[List[int]] = None,
-            truncate_prompt_tokens: Optional[Annotated[int,
-                                                       Field(ge=1)]] = None,
-            add_special_tokens: Optional[bool] = True
-    ) -> Tuple[List[int], str]:
-        if prompt and prompt_ids:
-            raise ValueError("Either prompt or prompt_ids should be provided.")
-        if (prompt and prompt_ids):
-            raise ValueError(
-                "Only one of prompt or prompt_ids should be provided.")
-
-        if prompt_ids is None:
-            # When using OpenAIServingChat for chat completions, for
-            # most models the special tokens (e.g., BOS) have already
-            # been added by the chat template. Therefore, we do not
-            # need to add them again.
-            # Set add_special_tokens to False (by default) to avoid
-            # adding the BOS tokens again.
-            tokenizer_kwargs: Dict[str, Any] = {
-                "add_special_tokens": add_special_tokens
-            }
-            if truncate_prompt_tokens is not None:
-                tokenizer_kwargs.update({
-                    "truncation": True,
-                    "max_length": truncate_prompt_tokens,
-                })
-            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
-        elif truncate_prompt_tokens is not None:
-            input_ids = prompt_ids[-truncate_prompt_tokens:]
+    def _normalize_prompt_text_to_input(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        prompt: str,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
+        add_special_tokens: bool,
+    ) -> TextTokensPrompt:
+        if truncate_prompt_tokens is None:
+            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
         else:
+            encoded = tokenizer(prompt,
+                                add_special_tokens=add_special_tokens,
+                                truncation=True,
+                                max_length=truncate_prompt_tokens)
+
+        input_ids = encoded.input_ids
+
+        input_text = prompt
+
+        return self._validate_input(request, input_ids, input_text)
+
+    def _normalize_prompt_tokens_to_input(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        prompt_ids: List[int],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
+    ) -> TextTokensPrompt:
+        if truncate_prompt_tokens is None:
             input_ids = prompt_ids
+        else:
+            input_ids = prompt_ids[-truncate_prompt_tokens:]
+
+        input_text = tokenizer.decode(input_ids)
+
+        return self._validate_input(request, input_ids, input_text)
 
-        input_text = prompt if prompt is not None else tokenizer.decode(
-            input_ids)
+    def _validate_input(
+        self,
+        request: AnyRequest,
+        input_ids: List[int],
+        input_text: str,
+    ) -> TextTokensPrompt:
         token_num = len(input_ids)
 
         # Note: EmbeddingRequest doesn't have max_tokens
@@ -210,13 +237,16 @@ class OpenAIServing:
                     f"This model's maximum context length is "
                     f"{self.max_model_len} tokens. However, you requested "
                     f"{token_num} tokens in the input for embedding "
-                    f"generation. Please reduce the length of the input.", )
-            return input_ids, input_text
+                    f"generation. Please reduce the length of the input.")
+            return TextTokensPrompt(prompt=input_text,
+                                    prompt_token_ids=input_ids)
 
         # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
         # and does not require model context length validation
-        if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
-            return input_ids, input_text
+        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
+                                DetokenizeRequest)):
+            return TextTokensPrompt(prompt=input_text,
+                                    prompt_token_ids=input_ids)
 
         if request.max_tokens is None:
             if token_num >= self.max_model_len:
@@ -224,7 +254,7 @@ class OpenAIServing:
                     f"This model's maximum context length is "
                     f"{self.max_model_len} tokens. However, you requested "
                     f"{token_num} tokens in the messages, "
-                    f"Please reduce the length of the messages.", )
+                    f"Please reduce the length of the messages.")
             request.max_tokens = self.max_model_len - token_num
 
         if token_num + request.max_tokens > self.max_model_len:
@@ -234,13 +264,131 @@ class OpenAIServing:
                 f"{request.max_tokens + token_num} tokens "
                 f"({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
-                f"Please reduce the length of the messages or completion.", )
+                f"Please reduce the length of the messages or completion.")
+
+        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
+
+    def _tokenize_prompt_input(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        prompt_input: Union[str, List[int]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = True,
+    ) -> TextTokensPrompt:
+        """
+        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
+        that assumes single input.
+        """
+        return next(
+            self._tokenize_prompt_inputs(
+                request,
+                tokenizer,
+                [prompt_input],
+                truncate_prompt_tokens=truncate_prompt_tokens,
+                add_special_tokens=add_special_tokens,
+            ))
+
+    def _tokenize_prompt_inputs(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        prompt_inputs: Iterable[Union[str, List[int]]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = True,
+    ) -> Iterator[TextTokensPrompt]:
+        """
+        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
+        that assumes multiple inputs.
+        """
+        for text in prompt_inputs:
+            if isinstance(text, str):
+                yield self._normalize_prompt_text_to_input(
+                    request,
+                    tokenizer,
+                    prompt=text,
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=add_special_tokens,
+                )
+            else:
+                yield self._normalize_prompt_tokens_to_input(
+                    request,
+                    tokenizer,
+                    prompt_ids=text,
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                )
+
+    def _tokenize_prompt_input_or_inputs(
+        self,
+        request: AnyRequest,
+        tokenizer: AnyTokenizer,
+        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
+        add_special_tokens: bool = True,
+    ) -> Iterator[TextTokensPrompt]:
+        """
+        Tokenize/detokenize depending on the input format.
+        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
+        , each input can be a string or array of tokens. Note that each request
+        can pass one or more inputs.
+        """
+        for prompt_input in parse_and_batch_prompt(input_or_inputs):
+            # Although our type checking is based on mypy,
+            # VSCode Pyright extension should still work properly
+            # "is True" is required for Pyright to perform type narrowing
+            # See: https://github.com/microsoft/pyright/issues/7672
+            if prompt_input["is_tokens"] is False:
+                yield self._normalize_prompt_text_to_input(
+                    request,
+                    tokenizer,
+                    prompt=prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                    add_special_tokens=add_special_tokens,
+                )
+            else:
+                yield self._normalize_prompt_tokens_to_input(
+                    request,
+                    tokenizer,
+                    prompt_ids=prompt_input["content"],
+                    truncate_prompt_tokens=truncate_prompt_tokens,
+                )
+
+    def _log_inputs(
+        self,
+        request_id: str,
+        inputs: Union[str, List[int], TextTokensPrompt],
+        params: Optional[Union[SamplingParams, PoolingParams]],
+        lora_request: Optional[LoRARequest],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+    ) -> None:
+        if self.request_logger is None:
+            return
+
+        if isinstance(inputs, str):
+            prompt = inputs
+            prompt_token_ids = None
+        elif isinstance(inputs, list):
+            prompt = None
+            prompt_token_ids = inputs
         else:
-            return input_ids, input_text
+            prompt = inputs["prompt"]
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        self.request_logger.log_inputs(
+            request_id,
+            prompt,
+            prompt_token_ids,
+            params=params,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
     @staticmethod
-    def _get_decoded_token(logprob: Logprob, token_id: int,
-                           tokenizer: PreTrainedTokenizer) -> str:
+    def _get_decoded_token(
+        logprob: Logprob,
+        token_id: int,
+        tokenizer: AnyTokenizer,
+    ) -> str:
         if logprob.decoded_token is not None:
             return logprob.decoded_token
         return tokenizer.decode(token_id)

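`_tokenize_prompt_input_or_inputs` relies on `parse_and_batch_prompt` to normalize the four OpenAI prompt shapes before tokenization. The sketch below illustrates that normalization as inferred from the `"is_tokens"`/`"content"` keys consumed above; the sample inputs are placeholders.

```python
# Sketch of the normalization performed by parse_and_batch_prompt, inferred
# from the "is_tokens"/"content" keys used in serving_engine.py. Inputs cover
# the four OpenAI shapes: string, list of strings, token list, list of token
# lists.
from aphrodite.inputs import parse_and_batch_prompt

examples = [
    "a single string",
    ["an array", "of strings"],
    [1, 2, 3],               # a single array of tokens
    [[1, 2, 3], [4, 5, 6]],  # an array of token arrays
]

for raw in examples:
    for item in parse_and_batch_prompt(raw):
        kind = "tokens" if item["is_tokens"] else "text"
        print(kind, item["content"])
```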
+ 78 - 27
aphrodite/endpoints/openai/serving_tokenization.py

@@ -1,13 +1,20 @@
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from aphrodite.common.config import ModelConfig
+from aphrodite.common.utils import random_uuid
 from aphrodite.endpoints.chat_utils import (ConversationMessage,
                                             load_chat_template,
                                             parse_chat_message_content)
+from aphrodite.endpoints.logger import RequestLogger
+# yapf conflicts with isort
+# yapf: disable
 from aphrodite.endpoints.openai.protocol import (DetokenizeRequest,
                                                  DetokenizeResponse,
+                                                 ErrorResponse,
+                                                 TokenizeChatRequest,
                                                  TokenizeRequest,
                                                  TokenizeResponse)
+# yapf: enable
 from aphrodite.endpoints.openai.serving_engine import (LoRAModulePath,
                                                        OpenAIServing)
 from aphrodite.engine.async_aphrodite import AsyncAphrodite
@@ -15,70 +22,114 @@ from aphrodite.engine.async_aphrodite import AsyncAphrodite
 
 class OpenAIServingTokenization(OpenAIServing):
 
-    def __init__(self,
-                 engine: AsyncAphrodite,
-                 model_config: ModelConfig,
-                 served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]] = None,
-                 chat_template: Optional[str] = None):
+    def __init__(
+        self,
+        engine: AsyncAphrodite,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         prompt_adapters=None,
+                         request_logger=request_logger)
 
         # If this is None we use the tokenizer's default chat template
         self.chat_template = load_chat_template(chat_template)
 
-    async def create_tokenize(self,
-                              request: TokenizeRequest) -> TokenizeResponse:
+    async def create_tokenize(
+        self,
+        request: TokenizeRequest,
+    ) -> Union[TokenizeResponse, ErrorResponse]:
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        if not (request.prompt or request.messages):
-            return self.create_error_response(
-                "Either `prompt` or `messages` should be provided.")
+        request_id = f"tokn-{random_uuid()}"
 
-        if (request.prompt and request.messages):
-            return self.create_error_response(
-                "Only one of `prompt` or `messages` should be provided.")
+        (
+            lora_request,
+            prompt_adapter_request,
+        ) = self._maybe_get_adapters(request)
 
-        _, lora_request = self._maybe_get_adapter(request)
         tokenizer = await self.engine.get_tokenizer(lora_request)
 
-        if request.messages:
+        if isinstance(request, TokenizeChatRequest):
+            model_config = self.model_config
+
             conversation: List[ConversationMessage] = []
 
             for message in request.messages:
-                result = parse_chat_message_content(message, self.model_config,
+                result = parse_chat_message_content(message, model_config,
                                                     tokenizer)
                 conversation.extend(result.messages)
 
-            request.prompt = tokenizer.apply_chat_template(
+            prompt = tokenizer.apply_chat_template(
                 add_generation_prompt=request.add_generation_prompt,
                 conversation=conversation,
                 tokenize=False,
                 chat_template=self.chat_template)
+            assert isinstance(prompt, str)
+        else:
+            prompt = request.prompt
+
+        self._log_inputs(request_id,
+                         prompt,
+                         params=None,
+                         lora_request=lora_request,
+                         prompt_adapter_request=prompt_adapter_request)
+
+        # Silently ignore prompt adapter since it does not affect tokenization
 
-        (input_ids, input_text) = await self._validate_prompt_and_tokenize(
+        prompt_input = self._tokenize_prompt_input(
             request,
             tokenizer,
-            prompt=request.prompt,
-            add_special_tokens=request.add_special_tokens)
+            prompt,
+            add_special_tokens=request.add_special_tokens,
+        )
+        input_ids = prompt_input["prompt_token_ids"]
 
         return TokenizeResponse(tokens=input_ids,
                                 count=len(input_ids),
                                 max_model_len=self.max_model_len)
 
     async def create_detokenize(
-            self, request: DetokenizeRequest) -> DetokenizeResponse:
+        self,
+        request: DetokenizeRequest,
+    ) -> Union[DetokenizeResponse, ErrorResponse]:
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        _, lora_request = self._maybe_get_adapter(request)
+        request_id = f"tokn-{random_uuid()}"
+
+        (
+            lora_request,
+            prompt_adapter_request,
+        ) = self._maybe_get_adapters(request)
+
         tokenizer = await self.engine.get_tokenizer(lora_request)
-        (input_ids, input_text) = await self._validate_prompt_and_tokenize(
-            request, tokenizer, prompt_ids=request.tokens)
+
+        self._log_inputs(request_id,
+                         request.tokens,
+                         params=None,
+                         lora_request=lora_request,
+                         prompt_adapter_request=prompt_adapter_request)
+
+        if prompt_adapter_request is not None:
+            raise NotImplementedError("Prompt adapter is not supported "
+                                      "for tokenization")
+
+        prompt_input = self._tokenize_prompt_input(
+            request,
+            tokenizer,
+            request.tokens,
+        )
+        input_text = prompt_input["prompt"]
 
         return DetokenizeResponse(prompt=input_text)
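
After this refactor the tokenization endpoint accepts either a plain prompt (TokenizeRequest) or chat messages (TokenizeChatRequest), applies the chat template in the chat case, and funnels both through _tokenize_prompt_input. A hedged client-side sketch; the base URL, port, and the /tokenize and /detokenize route names are assumptions about the OpenAI-compatible server, not taken from this diff:

# Sketch only: adjust BASE for your deployment; route names and port are assumed.
import requests

BASE = "http://localhost:2242"  # assumed default API server address

# Plain-text tokenization (TokenizeRequest): send `prompt`.
resp = requests.post(f"{BASE}/tokenize", json={
    "model": "my-model",           # served model name (placeholder)
    "prompt": "Hello, world!",
    "add_special_tokens": True,
})
tokens = resp.json()["tokens"]     # TokenizeResponse also returns count and max_model_len

# Chat-style tokenization (TokenizeChatRequest): send `messages` instead.
requests.post(f"{BASE}/tokenize", json={
    "model": "my-model",
    "messages": [{"role": "user", "content": "Hello!"}],
    "add_generation_prompt": True,
})

# Detokenization (DetokenizeRequest): send `tokens`, get back `prompt`.
text = requests.post(f"{BASE}/detokenize", json={
    "model": "my-model",
    "tokens": tokens,
}).json()["prompt"]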

+ 0 - 7
aphrodite/engine/args_tools.py

@@ -895,7 +895,6 @@ class AsyncEngineArgs(EngineArgs):
 
     engine_use_ray: bool = False
     disable_log_requests: bool = False
-    max_log_len: int = 0
     uvloop: bool = False
 
     @staticmethod
@@ -910,12 +909,6 @@ class AsyncEngineArgs(EngineArgs):
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='Disable logging requests.')
-        parser.add_argument('--max-log-len',
-                            type=int,
-                            default=0,
-                            help='Max number of prompt characters or prompt '
-                            'ID numbers being printed in log.'
-                            '\n\nDefault: Unlimited')
         parser.add_argument(
             "--uvloop",
             action="store_true",
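
--max-log-len disappears from the async engine arguments because prompt logging no longer happens inside AsyncAphrodite; it is handled by the RequestLogger that the serving classes (see serving_tokenization.py above) now receive. A minimal sketch of that kind of log-side truncation, assuming a max_log_len-style constructor argument; this is not the actual RequestLogger from aphrodite/endpoints/logger.py:

# Sketch only: log-side truncation of prompts, mirroring what the removed
# --max-log-len flag used to do inside the engine.
from typing import List, Optional, Union

class TruncatingRequestLogger:
    def __init__(self, max_log_len: Optional[int] = None):
        self.max_log_len = max_log_len  # None means "log everything"

    def log_inputs(self, request_id: str,
                   prompt: Union[str, List[int]]) -> None:
        shown = prompt
        if self.max_log_len is not None:
            shown = prompt[:self.max_log_len]  # works for str and token lists
        print(f"Received request {request_id}: prompt: {shown!r}")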

+ 15 - 34
aphrodite/engine/async_aphrodite.py

@@ -149,7 +149,10 @@ class RequestTracker:
             logger.info(f"Finished request {request_id}.")
         self.abort_request(request_id)
 
-    def add_request(self, request_id: str,
+    def add_request(self,
+                    request_id: str,
+                    *,
+                    verbose: bool = False,
                     **engine_add_request_kwargs) -> AsyncStream:
         """Add a request to be sent to the engine on the next background
         loop iteration."""
@@ -164,6 +167,9 @@ class RequestTracker:
 
         self.new_requests_event.set()
 
+        if verbose:
+            logger.info(f"Added request {request_id}.")
+
         return stream
 
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -295,14 +301,13 @@ class _AsyncAphrodite(AphroditeEngine):
         return self.input_processor(llm_inputs)
 
     async def add_request_async(
-            self,
-            request_id: str,
-            inputs: PromptInputs,
-            params: Union[SamplingParams, PoolingParams],
-            arrival_time: Optional[float] = None,
-            lora_request: Optional[LoRARequest] = None,
-            trace_headers: Optional[Dict[str, str]] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        request_id: str,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -351,8 +356,6 @@ class AsyncAphrodite:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
-        max_log_len: Maximum number of prompt characters or prompt ID numbers
-            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for AphroditeEngine.
@@ -366,13 +369,11 @@ class AsyncAphrodite:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 max_log_len: int = 0,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
         self.background_loop: Optional[asyncio.Future] = None
@@ -466,7 +467,6 @@ class AsyncAphrodite:
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
             log_stats=not engine_args.disable_log_stats,
-            max_log_len=engine_args.max_log_len,
             start_engine_loop=start_engine_loop,
             stat_loggers=stat_loggers,
         )
@@ -666,26 +666,6 @@ class AsyncAphrodite:
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> AsyncStream:
-        if self.log_requests:
-            if isinstance(inputs, str):
-                shortened_prompt = inputs
-                shortened_token_ids = None
-            else:
-                shortened_prompt = inputs.get("prompt")
-                shortened_token_ids = inputs.get("prompt_token_ids")
-
-            max_log_len = self.max_log_len
-            if max_log_len is not None:
-                if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:max_log_len]
-                if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:max_log_len]
-
-            logger.info(f"Received request {request_id}: "
-                        f"prompt: {shortened_prompt!r}, "
-                        f"params: {params}, "
-                        f"prompt_token_ids: {shortened_token_ids}, "
-                        f"lora_request: {lora_request}.")
 
         if not self.is_running:
             if self.start_engine_loop:
@@ -702,6 +682,7 @@ class AsyncAphrodite:
 
         stream = self._request_tracker.add_request(
             request_id,
+            verbose=self.log_requests,
             inputs=inputs,
             params=params,
             arrival_time=arrival_time,

+ 3 - 4
aphrodite/inputs/__init__.py

@@ -1,6 +1,5 @@
 from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
-                   PromptStrictInputs, TextPrompt, TextTokensPrompt,
-                   TokensPrompt, parse_and_batch_prompt)
+                   TextPrompt, TokensPrompt, parse_and_batch_prompt)
 from .registry import InputContext, InputRegistry
 
 INPUT_REGISTRY = InputRegistry()
@@ -14,6 +13,6 @@ See also:
 
 __all__ = [
     "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
-    "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
-    "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
+    "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
+    "InputContext", "InputRegistry"
 ]
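
With PromptStrictInputs removed, PromptInputs is the single union that callers import. A short sketch of the three accepted forms; the field names follow the TypedDicts in aphrodite/inputs/data.py and the literal values are placeholders:

# Sketch only: the three forms PromptInputs accepts after this change.
from aphrodite.inputs import PromptInputs, TextPrompt, TokensPrompt

plain: PromptInputs = "Hello, world!"                          # bare string
text: PromptInputs = TextPrompt(prompt="Hello, world!")        # text prompt
toks: PromptInputs = TokensPrompt(prompt_token_ids=[1, 2, 3])  # pre-tokenized prompt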

+ 1 - 23
aphrodite/inputs/data.py

@@ -91,35 +91,13 @@ class TokensPrompt(TypedDict):
     """
 
 
-class TextTokensPrompt(TypedDict):
-    """It is assumed that :attr:`prompt` is consistent with
-    :attr:`prompt_token_ids`. This is currently used in
-    :class:`AsyncAphrodite` for logging both the text and token IDs."""
-
-    prompt: str
-    """The prompt text."""
-
-    prompt_token_ids: List[int]
-    """The token IDs of the prompt."""
-
-    multi_modal_data: NotRequired["MultiModalDataDict"]
-    """
-    Optional multi-modal data to pass to the model,
-    if the model supports it.
-    """
-
-
-PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
+PromptInputs = Union[str, TextPrompt, TokensPrompt]
 """
 The inputs to the LLM, which can take one of the following forms:
 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
 """
 
-PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
-"""Same as :const:`PromptStrictInputs` but additionally accepts
-:class:`TextTokensPrompt`."""
-
 
 class LLMInputs(TypedDict):
     """