
refactor: consolidate prompt args to LLM engines

AlpinDale · 7 months ago · commit 90ceab32ff

+ 128 - 0
aphrodite/common/inputs.py

@@ -0,0 +1,128 @@
+from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
+                    TypedDict, Union, cast, overload)
+
+from typing_extensions import NotRequired
+
+if TYPE_CHECKING:
+    from aphrodite.common.sequence import MultiModalData
+
+
+class ParsedText(TypedDict):
+    content: str
+    is_tokens: Literal[False]
+
+
+class ParsedTokens(TypedDict):
+    content: List[int]
+    is_tokens: Literal[True]
+
+
+@overload
+def parse_and_batch_prompt(
+        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
+    ...
+
+
+@overload
+def parse_and_batch_prompt(
+        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
+    ...
+
+
+def parse_and_batch_prompt(
+    prompt: Union[str, List[str], List[int], List[List[int]]],
+) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
+    if isinstance(prompt, str):
+        # case 1: a string
+        return [ParsedText(content=prompt, is_tokens=False)]
+
+    if isinstance(prompt, list):
+        if len(prompt) == 0:
+            raise ValueError("please provide at least one prompt")
+
+        if isinstance(prompt[0], str):
+            # case 2: array of strings
+            return [
+                ParsedText(content=elem, is_tokens=False)
+                for elem in cast(List[str], prompt)
+            ]
+        if isinstance(prompt[0], int):
+            # case 3: array of tokens
+            elem = cast(List[int], prompt)
+            return [ParsedTokens(content=elem, is_tokens=True)]
+        if isinstance(prompt[0], list):
+            if len(prompt[0]) == 0:
+                raise ValueError("please provide at least one prompt")
+
+            if isinstance(prompt[0][0], int):
+                # case 4: array of token arrays
+                return [
+                    ParsedTokens(content=elem, is_tokens=True)
+                    for elem in cast(List[List[int]], prompt)
+                ]
+
+    raise ValueError("prompt must be a string, array of strings, "
+                     "array of tokens, or array of token arrays")
+
+
+class TextPrompt(TypedDict):
+    """Schema for a text prompt."""
+
+    prompt: str
+    """The input text to be tokenized before passing to the model."""
+
+    multi_modal_data: NotRequired["MultiModalData"]
+    """
+    Optional multi-modal data to pass to the model,
+    if the model supports it.
+    """
+
+
+class TokensPrompt(TypedDict):
+    """Schema for a tokenized prompt."""
+
+    prompt_token_ids: List[int]
+    """A list of token IDs to pass to the model."""
+
+    multi_modal_data: NotRequired["MultiModalData"]
+    """
+    Optional multi-modal data to pass to the model,
+    if the model supports it.
+    """
+
+
+class TextTokensPrompt(TypedDict):
+    """It is assumed that :attr:`prompt` is consistent with
+    :attr:`prompt_token_ids`. This is currently used in
+    :class:`AsyncLLMEngine` for logging both the text and token IDs."""
+
+    prompt: str
+    """The prompt text."""
+
+    prompt_token_ids: List[int]
+    """The token IDs of the prompt. If None, we use the
+    tokenizer to convert the prompts to token IDs."""
+
+    multi_modal_data: NotRequired["MultiModalData"]
+    """
+    Optional multi-modal data to pass to the model,
+    if the model supports it.
+    """
+
+
+PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
+"""
+The inputs to the LLM, which can take one of the following forms:
+- A text prompt (:class:`str` or :class:`TextPrompt`)
+- A tokenized prompt (:class:`TokensPrompt`)
+"""
+
+PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
+"""Same as :const:`PromptStrictInputs` but additionally accepts
+:class:`TextTokensPrompt`."""
+
+
+class LLMInputs(TypedDict):
+    prompt_token_ids: List[int]
+    prompt: Optional[str]
+    multi_modal_data: Optional["MultiModalData"]
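
The new module gives all engine entry points a single prompt schema: parse_and_batch_prompt normalizes the four legacy input shapes, and the TypedDicts describe what callers may pass. A minimal usage sketch, with illustrative token ids:

    from aphrodite.common.inputs import (TextPrompt, TokensPrompt,
                                         parse_and_batch_prompt)

    parse_and_batch_prompt("Hello")
    # -> [{'content': 'Hello', 'is_tokens': False}]
    parse_and_batch_prompt([[1, 2], [3, 4, 5]])
    # -> [{'content': [1, 2], 'is_tokens': True},
    #     {'content': [3, 4, 5], 'is_tokens': True}]

    text_input = TextPrompt(prompt="Hello, my name is")
    token_input = TokensPrompt(prompt_token_ids=[1, 2, 3])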

+ 14 - 26
aphrodite/common/outputs.py

@@ -1,4 +1,5 @@
 from typing import List, Optional, Union
+from dataclasses import dataclass
 import time
 
 from aphrodite.common.sequence import (
@@ -11,6 +12,7 @@ from aphrodite.common.sequence import (
 from aphrodite.lora.request import LoRARequest
 
 
+@dataclass
 class CompletionOutput:
     """The output data of one completion output of a request.
 
@@ -29,25 +31,14 @@ class CompletionOutput:
         lora_request: The LoRA request that was used to generate the output.
     """
 
-    def __init__(
-        self,
-        index: int,
-        text: str,
-        token_ids: List[int],
-        cumulative_logprob: float,
-        logprobs: Optional[SampleLogprobs],
-        finish_reason: Optional[str] = None,
-        stop_reason: Union[int, str, None] = None,
-        lora_request: Optional[LoRARequest] = None,
-    ) -> None:
-        self.index = index
-        self.text = text
-        self.token_ids = token_ids
-        self.cumulative_logprob = cumulative_logprob
-        self.logprobs = logprobs
-        self.finish_reason = finish_reason
-        self.stop_reason = stop_reason
-        self.lora_request = lora_request
+    index: int
+    text: str
+    token_ids: List[int]
+    cumulative_logprob: float
+    logprobs: Optional[SampleLogprobs]
+    finish_reason: Optional[str] = None
+    stop_reason: Union[int, str, None] = None
+    lora_request: Optional[LoRARequest] = None
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -62,6 +53,7 @@ class CompletionOutput:
                 f"stop_reason={self.stop_reason})")
 
 
+@dataclass
 class EmbeddingOutput:
     """The output data of one completion output of a request.
     Args:
@@ -69,15 +61,11 @@ class EmbeddingOutput:
         length of vector depends on the model as listed in the embedding guide.
     """
 
-    def __init__(
-        self,
-        embedding: List[float],
-    ) -> None:
-        self.embedding = embedding
+    embedding: List[float]
 
     def __repr__(self) -> str:
         return (f"EmbeddingOutput("
-                f"embedding={len(self.embedding)}")
+                f"embedding={len(self.embedding)})")
 
 
 class RequestOutput:
@@ -97,7 +85,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: str,
+        prompt: Optional[str],
         prompt_token_ids: List[int],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
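
Converting CompletionOutput and EmbeddingOutput to dataclasses keeps the old field order, so positional and keyword construction at existing call sites is unchanged, and the custom __repr__ methods defined in the class bodies are preserved. A short sketch with illustrative values:

    from aphrodite.common.outputs import CompletionOutput, EmbeddingOutput

    out = CompletionOutput(index=0, text=" world", token_ids=[995],
                           cumulative_logprob=-0.7, logprobs=None)
    out.finished()                          # False until finish_reason is set
    repr(EmbeddingOutput(embedding=[0.1, 0.2, 0.3]))
    # -> 'EmbeddingOutput(embedding=3)'  (closing parenthesis fixed above)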

+ 25 - 9
aphrodite/common/sequence.py

@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 from aphrodite.common.block import LogicalTokenBlock
+from aphrodite.common.inputs import LLMInputs
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.lora.request import LoRARequest
@@ -220,25 +221,24 @@ class Sequence:
     def __init__(
         self,
         seq_id: int,
-        prompt: str,
-        prompt_token_ids: List[int],
+        inputs: LLMInputs,
         block_size: int,
         eos_token_id: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.seq_id = seq_id
-        self.prompt = prompt
+        self.inputs = inputs
         self.block_size = block_size
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
 
-        self.data: SequenceData = SequenceData(prompt_token_ids)
+        self.data = SequenceData(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
         self.logical_token_blocks: List[LogicalTokenBlock] = []
         # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens_to_blocks(prompt_token_ids)
+        self._append_tokens_to_blocks(self.prompt_token_ids)
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None
 
@@ -248,6 +248,18 @@ class Sequence:
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
 
+    @property
+    def prompt(self) -> Optional[str]:
+        return self.inputs["prompt"]
+
+    @property
+    def prompt_token_ids(self) -> List[int]:
+        return self.inputs["prompt_token_ids"]
+
+    @property
+    def multi_modal_data(self) -> Optional["MultiModalData"]:
+        return self.inputs["multi_modal_data"]
+
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
@@ -429,7 +441,6 @@ class SequenceGroup:
         arrival_time: float,
         sampling_params: Optional[SamplingParams] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
         embeddings: Optional[List[float]] = None,
         pooling_params: Optional[PoolingParams] = None,
     ) -> None:
@@ -444,12 +455,11 @@ class SequenceGroup:
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
-        self.multi_modal_data = multi_modal_data
         self.embeddings = embeddings
         self.pooling_params = pooling_params
 
     @property
-    def prompt(self) -> str:
+    def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).prompt
@@ -458,7 +468,13 @@ class SequenceGroup:
     def prompt_token_ids(self) -> List[int]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return next(iter(self.seqs_dict.values())).data.prompt_token_ids
+        return next(iter(self.seqs_dict.values())).prompt_token_ids
+
+    @property
+    def multi_modal_data(self) -> Optional[MultiModalData]:
+        # All sequences in the group should have the same multi-modal data.
+        # We use the multi-modal data of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).multi_modal_data
 
     @property
     def lora_int_id(self) -> int:
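
A Sequence is now constructed from one LLMInputs dict, and prompt, prompt_token_ids, and multi_modal_data are exposed as read-only properties over it, which is also what SequenceGroup delegates to. A hedged sketch with illustrative ids and block size:

    from aphrodite.common.inputs import LLMInputs
    from aphrodite.common.sequence import Sequence

    inputs = LLMInputs(prompt="Hello", prompt_token_ids=[1, 2, 3],
                       multi_modal_data=None)
    seq = Sequence(seq_id=0, inputs=inputs, block_size=16, eos_token_id=2)
    assert seq.prompt == "Hello"
    assert seq.prompt_token_ids == [1, 2, 3]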

+ 42 - 1
aphrodite/common/utils.py

@@ -11,7 +11,7 @@ import threading
 import uuid
 import warnings
 from collections import defaultdict
-from functools import lru_cache, partial
+from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
                     Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
@@ -646,3 +646,44 @@ def enable_trace_function_call_for_thread() -> None:
                                 get_aphrodite_instance_id(), filename)
         os.makedirs(os.path.dirname(log_path), exist_ok=True)
         enable_trace_function_call(log_path)
+
+
+def identity(value: T) -> T:
+    return value
+
+
+F = TypeVar('F', bound=Callable[..., Any])
+
+
+def deprecate_kwargs(
+        *kws: str,
+        is_deprecated: Union[bool, Callable[[], bool]] = True,
+        additional_message: Optional[str] = None) -> Callable[[F], F]:
+    deprecated_kws = set(kws)
+
+    if not callable(is_deprecated):
+        is_deprecated = partial(identity, is_deprecated)
+
+    def wrapper(fn: F) -> F:
+
+        @wraps(fn)
+        def inner(*args, **kwargs):
+            if is_deprecated():
+                deprecated_kwargs = kwargs.keys() & deprecated_kws
+                if deprecated_kwargs:
+                    msg = (
+                        f"The keyword arguments {deprecated_kwargs} are "
+                        "deprecated and will be removed in a future update.")
+                    if additional_message is not None:
+                        msg += f" {additional_message}"
+
+                    warnings.warn(
+                        DeprecationWarning(msg),
+                        stacklevel=3,  # The inner function takes up one level
+                    )
+
+            return fn(*args, **kwargs)
+
+        return inner  # type: ignore
+
+    return wrapper
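
deprecate_kwargs only warns when one of the named keywords is actually passed and is_deprecated() evaluates truthy (it defaults to True). A sketch with hypothetical parameter names old_arg/new_arg:

    from typing import Optional

    from aphrodite.common.utils import deprecate_kwargs

    @deprecate_kwargs("old_arg", additional_message="Use 'new_arg' instead.")
    def greet(new_arg: str = "hi", old_arg: Optional[str] = None) -> str:
        return old_arg if old_arg is not None else new_arg

    greet(new_arg="hello")  # no warning: the deprecated keyword is not used
    greet(old_arg="hello")  # emits a DeprecationWarning, still returns "hello"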

+ 331 - 116
aphrodite/endpoints/llm.py

@@ -1,14 +1,17 @@
-from typing import List, Optional, Union
+from contextlib import contextmanager
+from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
 
-import torch
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
+from aphrodite.common.inputs import (PromptInputs, PromptStrictInputs,
+                                     TextPrompt, TextTokensPrompt,
+                                     TokensPrompt, parse_and_batch_prompt)
 from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import MultiModalData
-from aphrodite.common.utils import Counter
+from aphrodite.common.utils import Counter, deprecate_kwargs
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.lora.request import LoRARequest
@@ -24,14 +27,19 @@ class LLM:
     mechanism and efficient memory management.
 
     NOTE: This class is intended to be used for offline inference. For online
-    serving, use the `AsyncLLMEngine` class instead.
-    NOTE: For the comprehensive list of arguments, see `EngineArgs`.
+    serving, use the :class:`~aphrodite.AsyncAphrodite` class instead.
+
+    NOTE: For the comprehensive list of arguments, see
+    :class:`~aphrodite.EngineArgs`.
 
     Args:
         model: The name or path of a HuggingFace Transformers model.
         tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
             if available, and "slow" will always use the slow tokenizer.
+        skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer. Expect valid prompt_token_ids and None for prompt
+            from the input.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
@@ -42,12 +50,15 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq", "gptq", "quip" and "squeezellm". If None,
-            we first check the `quantization_config` attribute in the model
-            config file. If that is None, we assume the model weights are not
-            quantized and use `dtype` to determine the data type of the weights.
+            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
+            If None, we first check the `quantization_config` attribute in the
+            model config file. If that is None, we assume the model weights are
+            not quantized and use `dtype` to determine the data type of
+            the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id.
         seed: The seed to initialize the random number generator for sampling.
         gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
             reserve for the model weights, activations, and KV cache. Higher
@@ -68,27 +79,40 @@ class LLM:
         max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode.
-        disable_custom_all_reduce: See ParallelConfig.
+        disable_custom_all_reduce: See ParallelConfig
     """
 
+    DEPRECATE_LEGACY: ClassVar[bool] = False
+    """A flag to toggle whether to deprecate the legacy generate/encode API."""
+
+    @classmethod
+    @contextmanager
+    def deprecate_legacy_api(cls):
+        cls.DEPRECATE_LEGACY = True
+
+        yield
+
+        cls.DEPRECATE_LEGACY = False
+
     def __init__(
         self,
         model: str,
         tokenizer: Optional[str] = None,
         tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
-        enforce_eager: bool = True,
+        enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
-        enable_prefix_caching: bool = False,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
@@ -97,11 +121,13 @@ class LLM:
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
+            skip_tokenizer_init=skip_tokenizer_init,
             trust_remote_code=trust_remote_code,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             quantization=quantization,
             revision=revision,
+            tokenizer_revision=tokenizer_revision,
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
@@ -109,7 +135,6 @@ class LLM:
             max_context_len_to_capture=max_context_len_to_capture,
             max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
-            enable_prefix_caching=enable_prefix_caching,
             **kwargs,
         )
         self.llm_engine = AphroditeEngine.from_engine_args(engine_args)
@@ -117,23 +142,109 @@ class LLM:
 
     def get_tokenizer(
             self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_engine.tokenizer
+        return self.llm_engine.tokenizer.tokenizer
 
     def set_tokenizer(
         self,
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     ) -> None:
-        self.llm_engine.tokenizer = tokenizer
+        self.llm_engine.tokenizer.tokenizer = tokenizer
 
+    @overload  # LEGACY: single (prompt + optional token ids)
     def generate(
         self,
-        prompts: Optional[Union[str, List[str]]] = None,
+        prompts: str,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        prompt_token_ids: Optional[List[int]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (prompt + optional token ids)
+    def generate(
+        self,
+        prompts: List[str],
         sampling_params: Optional[Union[SamplingParams,
                                         List[SamplingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: single (token ids + optional prompt)
+    def generate(
+        self,
+        prompts: Optional[str] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        *,
+        prompt_token_ids: List[int],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (token ids + optional prompt)
+    def generate(
+        self,
+        prompts: Optional[List[str]] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        *,
+        prompt_token_ids: List[List[int]],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload  # LEGACY: single or multi token ids [pos-only]
+    def generate(
+        self,
+        prompts: None,
+        sampling_params: None,
+        prompt_token_ids: Union[List[int], List[List[int]]],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @overload
+    def generate(
+        self,
+        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
+        *,
+        sampling_params: Optional[Union[SamplingParams,
+                                        Sequence[SamplingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> List[RequestOutput]:
+        ...
+
+    @deprecate_kwargs("prompts",
+                      "prompt_token_ids",
+                      "multi_modal_data",
+                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+                      additional_message="Please use the 'inputs' parameter "
+                      "instead.")
+    def generate(
+        self,
+        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+                       Optional[Union[str, List[str]]]] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        Sequence[SamplingParams]]] = None,
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
 
@@ -142,170 +253,275 @@ class LLM:
         into a single list and pass it to this method.
 
         Args:
-            prompts: A list of prompts to generate completions for.
+            inputs: A list of inputs to generate completions for.
             sampling_params: The sampling parameters for text generation. If
-                None, we use the default sampling parameters.
-                When it's a single value, it's applied to every prompt.
-                When it's a list, the list must have the same length as
-                the prompts and it's paired one-to-one with the prompts.
-            prompt_token_ids: A list of token IDs for the prompts. If None, we
-                use the tokenizer to convert the prompts to token IDs.
+                None, we use the default sampling parameters. 
+                When it is a single value, it is applied to every prompt. 
+                When it is a list, the list must have the same length as the 
+                prompts and it is paired one by one with the prompt.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi-modal data to use for generation, if any.
 
         Returns:
             A list of `RequestOutput` objects containing the
             generated completions in the same order as the input prompts.
         """
+        if prompt_token_ids is not None or multi_modal_data is not None:
+            inputs = self._convert_v1_inputs(
+                prompts=cast(Optional[Union[str, List[str]]], prompts),
+                prompt_token_ids=prompt_token_ids,
+                multi_modal_data=multi_modal_data,
+            )
+        else:
+            inputs = cast(
+                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+                prompts)
+
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
 
-        requests_data = self._validate_and_prepare_requests(
-            prompts,
-            sampling_params,
-            prompt_token_ids,
-            lora_request,
-            multi_modal_data,
+        self._validate_and_add_requests(
+            inputs=inputs,
+            params=sampling_params,
+            lora_request=lora_request,
         )
 
-        # Add requests to the engine and run the engine
-        for request_data in requests_data:
-            self._add_request(**request_data)
+        outputs = self._run_engine(use_tqdm=use_tqdm)
+        return AphroditeEngine.validate_outputs(outputs, RequestOutput)
 
-        return self._run_engine(use_tqdm)
+    @overload  # LEGACY: single (prompt + optional token ids)
+    def encode(
+        self,
+        prompts: str,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        prompt_token_ids: Optional[List[int]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
 
+    @overload  # LEGACY: multi (prompt + optional token ids)
     def encode(
         self,
-        prompts: Optional[Union[str, List[str]]] = None,
+        prompts: List[str],
         pooling_params: Optional[Union[PoolingParams,
-                                       List[PoolingParams]]] = None,
+                                       Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: single (token ids + optional prompt)
+    def encode(
+        self,
+        prompts: Optional[str] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        *,
+        prompt_token_ids: List[int],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: multi (token ids + optional prompt)
+    def encode(
+        self,
+        prompts: Optional[List[str]] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        *,
+        prompt_token_ids: List[List[int]],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload  # LEGACY: single or multi token ids [pos-only]
+    def encode(
+        self,
+        prompts: None,
+        pooling_params: None,
+        prompt_token_ids: Union[List[int], List[List[int]]],
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @overload
+    def encode(
+        self,
+        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        /,  # We may enable `inputs` keyword after removing the old API
+        *,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        ...
+
+    @deprecate_kwargs("prompts",
+                      "prompt_token_ids",
+                      "multi_modal_data",
+                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+                      additional_message="Please use the 'inputs' parameter "
+                      "instead.")
+    def encode(
+        self,
+        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+                       Optional[Union[str, List[str]]]] = None,
+        pooling_params: Optional[Union[PoolingParams,
+                                       Sequence[PoolingParams]]] = None,
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[EmbeddingRequestOutput]:
         """Generates the completions for the input prompts.
+
         NOTE: This class automatically batches the given prompts, considering
         the memory constraint. For the best performance, put all of your prompts
         into a single list and pass it to this method.
+
         Args:
-            prompts: A list of prompts to generate completions for.
+            inputs: The inputs to the LLM. You may pass a sequence of inputs for
+                batch inference. See
+                :class:`~aphrodite.inputs.PromptStrictInputs`
+                for more details about the format of each input.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
-            prompt_token_ids: A list of token IDs for the prompts. If None, we
-                use the tokenizer to convert the prompts to token IDs.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi modal data.
+
         Returns:
             A list of `EmbeddingRequestOutput` objects containing the
             generated embeddings in the same order as the input prompts.
         """
+        if prompt_token_ids is not None or multi_modal_data is not None:
+            inputs = self._convert_v1_inputs(
+                prompts=cast(Optional[Union[str, List[str]]], prompts),
+                prompt_token_ids=prompt_token_ids,
+                multi_modal_data=multi_modal_data,
+            )
+        else:
+            inputs = cast(
+                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+                prompts)
+
         if pooling_params is None:
             # Use default pooling params.
             pooling_params = PoolingParams()
 
-        requests_data = self._validate_and_prepare_requests(
-            prompts,
-            pooling_params,
-            prompt_token_ids,
-            lora_request,
-            multi_modal_data,
+        self._validate_and_add_requests(
+            inputs=inputs,
+            params=pooling_params,
+            lora_request=lora_request,
         )
 
-        # Add requests to the engine and run the engine
-        for request_data in requests_data:
-            self._add_request(**request_data)
-
-        return self._run_engine(use_tqdm)
+        outputs = self._run_engine(use_tqdm=use_tqdm)
+        return AphroditeEngine.validate_outputs(outputs,
+                                                EmbeddingRequestOutput)
 
-    def _validate_and_prepare_requests(
+    # LEGACY
+    def _convert_v1_inputs(
         self,
         prompts: Optional[Union[str, List[str]]],
-        params: Union[Union[SamplingParams, PoolingParams],
-                      List[Union[SamplingParams,
-                                 PoolingParams]]],  # Unified parameter
-        prompt_token_ids: Optional[List[List[int]]] = None,
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
-    ) -> List[dict]:
-        """Validates and prepares request data for adding to the engine.
-        Ensures prompts and token IDs are consistent, and returns a list of
-        dictionaries with request data for further processing.
-        """
-        if prompts is None and prompt_token_ids is None:
-            raise ValueError("Either prompts or prompt_token_ids must be "
-                             "provided.")
-        if isinstance(prompts, str):
-            # Convert a single prompt to a list.
-            prompts = [prompts]
-        if prompts is not None and prompt_token_ids is not None and len(
-                prompts) != len(prompt_token_ids):
-            raise ValueError(
-                "The lengths of prompts and prompt_token_ids must "
-                "be the same.")
+        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
+        multi_modal_data: Optional[MultiModalData],
+    ):
+        # skip_tokenizer_init is now checked in engine
+
+        if prompts is not None:
+            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
+        if prompt_token_ids is not None:
+            prompt_token_ids = [
+                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
+            ]
 
+        num_requests = None
         if prompts is not None:
             num_requests = len(prompts)
-        else:
-            assert prompt_token_ids is not None
+        if prompt_token_ids is not None:
+            if (num_requests is not None
+                    and num_requests != len(prompt_token_ids)):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
+
             num_requests = len(prompt_token_ids)
+        if num_requests is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
+
+        inputs: List[PromptInputs] = []
+        for i in range(num_requests):
+            if prompts is not None:
+                if prompt_token_ids is not None:
+                    item = TextTokensPrompt(
+                        prompt=prompts[i],
+                        prompt_token_ids=prompt_token_ids[i])
+                else:
+                    item = TextPrompt(prompt=prompts[i])
+            else:
+                if prompt_token_ids is not None:
+                    item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
+                else:
+                    raise AssertionError
+
+            if multi_modal_data is not None:
+                item["multi_modal_data"] = multi_modal_data
+
+            inputs.append(item)
+
+        return inputs
+
+    def _validate_and_add_requests(
+        self,
+        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
+                      Sequence[PoolingParams]],
+        lora_request: Optional[LoRARequest],
+    ) -> None:
+        if isinstance(inputs, (str, dict)):
+            # Convert a single prompt to a list.
+            inputs = [inputs]
+
+        num_requests = len(inputs)
 
         if isinstance(params, list) and len(params) != num_requests:
             raise ValueError("The lengths of prompts and params "
-                             "be the same.")
-
-        if multi_modal_data:
-            multi_modal_data.data = multi_modal_data.data.to(torch.float16)
+                             "must be the same.")
 
         # Add requests to the engine.
-        requests_data = []
-        num_requests = len(prompts) if prompts is not None else len(
-            prompt_token_ids)
-        for i in range(num_requests):
-            prompt = prompts[i] if prompts is not None else None
-            token_ids = None if prompt_token_ids is None else prompt_token_ids[
-                i]
-            multi_modal_item = MultiModalData(
-                type=multi_modal_data.type,
-                data=multi_modal_data.data[i].unsqueeze(0),
-            ) if multi_modal_data else None
-
-            requests_data.append({
-                "prompt":
-                prompt,
-                "params":
-                params[i] if isinstance(params, list) else params,
-                "prompt_token_ids":
-                token_ids,
-                "lora_request":
-                lora_request,
-                "multi_modal_data":
-                multi_modal_item,
-            })
-
-        return requests_data
+        for i, request_inputs in enumerate(inputs):
+            self._add_request(
+                request_inputs,
+                params[i] if isinstance(params, Sequence) else params,
+                lora_request=lora_request,
+            )
 
     def _add_request(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]],
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(request_id,
-                                    prompt,
+                                    inputs,
                                     params,
-                                    prompt_token_ids,
-                                    lora_request=lora_request,
-                                    multi_modal_data=multi_modal_data)
+                                    lora_request=lora_request)
 
     def _run_engine(
-            self, use_tqdm: bool
+            self, *, use_tqdm: bool
     ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         # Initialize tqdm.
         if use_tqdm:
@@ -337,5 +553,4 @@ class LLM:
         # Sort the outputs by request ID.
         # This is necessary because some requests may be finished earlier than
         # its previous requests.
-        outputs = sorted(outputs, key=lambda x: int(x.request_id))
-        return outputs
+        return sorted(outputs, key=lambda x: int(x.request_id))
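
With these overloads in place, LLM.generate and LLM.encode take a single prompts/inputs argument in the new form: plain strings, TextPrompt dicts, and TokensPrompt dicts may be mixed in one batch, while the legacy prompts/prompt_token_ids keywords are converted by _convert_v1_inputs and warn once LLM.DEPRECATE_LEGACY is enabled. A hedged usage sketch; the model name and token ids are illustrative:

    from aphrodite.common.inputs import TokensPrompt
    from aphrodite.common.sampling_params import SamplingParams
    from aphrodite.endpoints.llm import LLM

    llm = LLM(model="facebook/opt-125m")

    # New API: one positional argument carrying all prompts.
    outputs = llm.generate(
        ["The capital of France is", TokensPrompt(prompt_token_ids=[1, 2, 3])],
        sampling_params=SamplingParams(max_tokens=16),
    )

    # Legacy API: still accepted, converted internally to the new schema.
    legacy = llm.generate(prompts=["Hello"], prompt_token_ids=[[9906]])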

+ 9 - 3
aphrodite/endpoints/openai/serving_chat.py

@@ -92,9 +92,15 @@ class OpenAIServingChat(OpenAIServing):
         except ValueError as e:
             return self.create_error_response(str(e))
 
-        result_generator = self.engine.generate(prompt_text, sampling_params,
-                                                request_id, prompt_ids,
-                                                lora_request)
+        result_generator = self.engine.generate(
+            {
+                "prompt": prompt_text,
+                "prompt_token_ids": prompt_ids
+            },
+            sampling_params,
+            request_id,
+            lora_request,
+        )
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
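
The chat endpoint now hands the prompt text and its token ids to the async engine as one inputs dict instead of two positional arguments. A minimal sketch of that shape (token ids are illustrative; engine, sampling_params, request_id, and lora_request stand for the handler's local variables):

    from aphrodite.common.inputs import TextTokensPrompt

    inputs = TextTokensPrompt(prompt="Hello!", prompt_token_ids=[9906, 0])
    # result_generator = engine.generate(inputs, sampling_params, request_id,
    #                                    lora_request)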

+ 11 - 6
aphrodite/endpoints/openai/serving_completions.py

@@ -123,12 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
                         truncate_prompt_tokens)
                 prompt_ids, prompt_text = prompt_formats
 
-                generators.append(
-                    self.engine.generate(prompt_text,
-                                         sampling_params,
-                                         f"{request_id}-{i}",
-                                         prompt_token_ids=prompt_ids,
-                                         lora_request=lora_request))
+                generator = self.engine.generate(
+                    {
+                        "prompt": prompt_text,
+                        "prompt_token_ids": prompt_ids
+                    },
+                    sampling_params,
+                    f"{request_id}-{i}",
+                    lora_request=lora_request,
+                )
+
+                generators.append(generator)
         except ValueError as e:
             # TODO: Use a specific-specific Validation Error
             return self.create_error_response(str(e))

+ 25 - 15
aphrodite/endpoints/openai/serving_embedding.py

@@ -96,11 +96,16 @@ class OpenAIServingEmbedding(OpenAIServing):
 
                 prompt_ids, prompt_text = prompt_formats
 
-                generators.append(
-                    self.engine.generate(prompt_text,
-                                         pooling_params,
-                                         f"{request_id}-{i}",
-                                         prompt_token_ids=prompt_ids))
+                generator = self.engine.encode(
+                    {
+                        "prompt": prompt_text,
+                        "prompt_token_ids": prompt_ids
+                    },
+                    pooling_params,
+                    f"{request_id}-{i}",
+                )
+
+                generators.append(generator)
         except ValueError as e:
             # TODO: Use a aphrodite-specific Validation Error
             return self.create_error_response(str(e))
@@ -109,15 +114,20 @@ class OpenAIServingEmbedding(OpenAIServing):
             int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
 
         # Non-streaming response
-        final_res_batch: EmbeddingRequestOutput = [None] * len(prompts)
-        async for i, res in result_generator:
-            if await raw_request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.engine.abort(f"{request_id}-{i}")
-                # TODO: Use a aphrodite-specific Validation Error
-                return self.create_error_response("Client disconnected")
-            final_res_batch[i] = res
-        response = request_output_to_embedding_response(
-            final_res_batch, request_id, created_time, model_name)
+        final_res_batch: List[Optional[EmbeddingRequestOutput]]
+        final_res_batch = [None] * len(prompts)
+        try:
+            async for i, res in result_generator:
+                if await raw_request.is_disconnected():
+                    # Abort the request if the client disconnects.
+                    await self.engine.abort(f"{request_id}-{i}")
+                    # TODO: Use a aphrodite-specific Validation Error
+                    return self.create_error_response("Client disconnected")
+                final_res_batch[i] = res
+            response = request_output_to_embedding_response(
+                final_res_batch, request_id, created_time, model_name)
+        except ValueError as e:
+            # TODO: Use a aphrodite-specific Validation Error
+            return self.create_error_response(str(e))
 
         return response
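
The embedding handler now pre-sizes final_res_batch with Optional slots and writes each result back by the index yielded from merge_async_iterators, so out-of-order completion is handled the same way as in the other endpoints. A generic asyncio sketch of that indexed-gather idea (not the merge_async_iterators implementation itself):

    import asyncio
    from typing import AsyncIterator, List, Optional, Sequence

    async def gather_indexed(
            gens: Sequence[AsyncIterator[str]]) -> List[Optional[str]]:
        # Results may arrive out of order, so slot each by its generator index.
        final: List[Optional[str]] = [None] * len(gens)

        async def drain(i: int, gen: AsyncIterator[str]) -> None:
            async for item in gen:
                final[i] = item  # keep the latest item for this slot

        await asyncio.gather(*(drain(i, g) for i, g in enumerate(gens)))
        return final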

+ 8 - 2
aphrodite/endpoints/openai/serving_engine.py

@@ -172,7 +172,10 @@ class OpenAIServing:
         })
         return json_str
 
-    async def _check_model(self, request) -> Optional[ErrorResponse]:
+    async def _check_model(
+        self, request: Union[CompletionRequest, ChatCompletionRequest,
+                             EmbeddingRequest]
+    ) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
             return
         if request.model in [lora.lora_name for lora in self.lora_requests]:
@@ -200,7 +203,10 @@ class OpenAIServing:
             lora for lora in self.lora_requests if lora.lora_name != lora_name
         ]
 
-    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+    def _maybe_get_lora(
+        self, request: Union[CompletionRequest, ChatCompletionRequest,
+                             EmbeddingRequest]
+    ) -> Optional[LoRARequest]:
         if request.model in self.served_model_names:
             return
         for lora in self.lora_requests:

+ 170 - 82
aphrodite/engine/aphrodite_engine.py

@@ -1,5 +1,8 @@
 import time
-from typing import Iterable, List, Optional, Type, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Type, TypeVar, Union
 
 from loguru import logger
 from transformers import GenerationConfig, PreTrainedTokenizer
@@ -9,16 +12,16 @@ from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
                                      LoadConfig, LoRAConfig, ModelConfig,
                                      ParallelConfig, SchedulerConfig,
                                      SpeculativeConfig, VisionLanguageConfig)
+from aphrodite.common.inputs import LLMInputs, PromptInputs
 from aphrodite.common.logger import setup_logger
 from aphrodite.common.outputs import (EmbeddingRequestOutput, RequestOutput,
                                       RequestOutputFactory)
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
-                                       ExecuteModelRequest, MultiModalData,
-                                       PoolerOutput, SamplerOutput, Sequence,
-                                       SequenceGroup, SequenceGroupMetadata,
-                                       SequenceStatus)
+                                       ExecuteModelRequest, PoolerOutput,
+                                       SamplerOutput, Sequence, SequenceGroup,
+                                       SequenceGroupMetadata, SequenceStatus)
 from aphrodite.common.utils import Counter
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics import StatLogger, Stats
@@ -50,6 +53,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
         return {}
 
 
+_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
+
+
 class AphroditeEngine:
     """An LLM engine that receives requests and generates texts.
 
@@ -83,6 +89,57 @@ class AphroditeEngine:
         log_stats: Whether to log statistics.
     """
 
+    DO_VALIDATE_OUTPUT: ClassVar[bool] = False
+    """A flag to toggle whether to validate the type of request output."""
+
+    @classmethod
+    @contextmanager
+    def enable_output_validation(cls):
+        cls.DO_VALIDATE_OUTPUT = True
+
+        yield
+
+        cls.DO_VALIDATE_OUTPUT = False
+
+    @classmethod
+    def validate_output(
+        cls,
+        output: object,
+        output_type: Type[_O],
+    ) -> _O:
+        do_validate = cls.DO_VALIDATE_OUTPUT
+
+        if ((TYPE_CHECKING or do_validate)
+                and not isinstance(output, output_type)):
+            raise TypeError(f"Expected output of type {output_type}, "
+                            f"but found type {type(output)}")
+
+        return output
+
+    @classmethod
+    def validate_outputs(
+        cls,
+        outputs: GenericSequence[object],
+        output_type: Type[_O],
+    ) -> List[_O]:
+        do_validate = cls.DO_VALIDATE_OUTPUT
+
+        outputs_: List[_O]
+        if TYPE_CHECKING or do_validate:
+            outputs_ = []
+            for output in outputs:
+                if not isinstance(output, output_type):
+                    raise TypeError(f"Expected output of type {output_type}, "
+                                    f"but found type {type(output)}")
+
+                outputs_.append(output)
+        else:
+            outputs_ = outputs
+
+        return outputs_
+
+    tokenizer: Optional[BaseTokenizerGroup]
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -133,12 +190,11 @@ class AphroditeEngine:
         self.log_stats = log_stats
 
         if not self.model_config.skip_tokenizer_init:
-            self.tokenizer: BaseTokenizerGroup
-            self._init_tokenizer()
+            self.tokenizer = self._init_tokenizer()
             self.detokenizer = Detokenizer(self.tokenizer)
         else:
-            self.detokenizer = None
             self.tokenizer = None
+            self.detokenizer = None
 
         self.seq_counter = Counter()
         self.generation_config_fields = _load_generation_config_dict(
@@ -235,8 +291,8 @@ class AphroditeEngine:
             from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
             executor_class = RayGPUExecutor
         elif distributed_executor_backend == "mp":
-            from aphrodite.executor.multiproc_gpu_executor import (
-                MultiprocessingGPUExecutor)
+            from aphrodite.executor.multiproc_gpu_executor import \
+                MultiprocessingGPUExecutor
             executor_class = MultiprocessingGPUExecutor
         else:
             from aphrodite.executor.gpu_executor import GPUExecutor
@@ -261,14 +317,26 @@ class AphroditeEngine:
         if model_executor := getattr(self, "model_executor", None):
             model_executor.shutdown()
 
+    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
+                                   "skip_tokenizer_init is True")
+
+    def get_tokenizer_group(
+            self,
+            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
+        if self.tokenizer is None:
+            raise ValueError(fail_msg)
+
+        return self.tokenizer
+
     def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer(None)
+        return self.get_tokenizer_group().get_lora_tokenizer(None)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
-        return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
+        return self.get_tokenizer_group().get_lora_tokenizer(
+            sequence.lora_request)
 
-    def _init_tokenizer(self, **tokenizer_init_kwargs):
+    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
         init_kwargs = dict(
             tokenizer_id=self.model_config.tokenizer,
             enable_lora=bool(self.lora_config),
@@ -278,8 +346,8 @@ class AphroditeEngine:
             trust_remote_code=self.model_config.trust_remote_code,
             revision=self.model_config.tokenizer_revision)
         init_kwargs.update(tokenizer_init_kwargs)
-        self.tokenizer = get_tokenizer_group(
-            self.parallel_config.tokenizer_pool_config, **init_kwargs)
+        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
+                                   **init_kwargs)
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -289,29 +357,85 @@ class AphroditeEngine:
             self.lora_config.verify_with_scheduler_config(
                 self.scheduler_config)
 
-    def encode_request(
+    def _get_eos_token_id(
+            self, lora_request: Optional[LoRARequest]) -> Optional[int]:
+        if self.tokenizer is None:
+            logger.warning("Using None for EOS token id because tokenizer "
+                           "is not initialized")
+            return None
+
+        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
+
+    def _add_processed_request(
+        self,
+        request_id: str,
+        processed_inputs: LLMInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
+    ) -> None:
+        # Create the sequences.
+        block_size = self.cache_config.block_size
+        seq_id = next(self.seq_counter)
+        eos_token_id = self._get_eos_token_id(lora_request)
+
+        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
+                       lora_request)
+
+        # Create a SequenceGroup based on SamplingParams or PoolingParams
+        if isinstance(params, SamplingParams):
+            seq_group = self._create_sequence_group_with_sampling(
+                request_id,
+                seq,
+                params,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+            )
+        elif isinstance(params, PoolingParams):
+            seq_group = self._create_sequence_group_with_pooling(
+                request_id,
+                seq,
+                params,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+            )
+        else:
+            raise ValueError(
+                "Either SamplingParams or PoolingParams must be provided.")
+
+        # Add the sequence group to the scheduler.
+        self.scheduler.add_seq_group(seq_group)
+
+    def process_model_inputs(
         self,
-        request_id: str,  # pylint: disable=unused-argument
-        prompt: Optional[str],
-        prompt_token_ids: Optional[List[int]] = None,
+        request_id: str,
+        inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
-    ):
-        if prompt_token_ids is None:
-            assert prompt is not None
-            prompt_token_ids = self.tokenizer.encode(request_id=request_id,
-                                                     prompt=prompt,
-                                                     lora_request=lora_request)
-        return prompt_token_ids
+    ) -> LLMInputs:
+        if isinstance(inputs, str):
+            inputs = {"prompt": inputs}
+
+        if "prompt_token_ids" not in inputs:
+            tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                                 "skip_tokenizer_init is True")
+
+            prompt_token_ids = tokenizer.encode(request_id=request_id,
+                                                prompt=inputs["prompt"],
+                                                lora_request=lora_request)
+        else:
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=inputs.get("prompt"),
+                         multi_modal_data=inputs.get("multi_modal_data"))
 
     def add_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         """Add a request to the engine's request pool.
 
@@ -362,59 +486,26 @@ class AphroditeEngine:
                              "not enabled!")
         if arrival_time is None:
             arrival_time = time.time()
-        prompt_token_ids = self.encode_request(
-            request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request)
-
-        # Create the sequences.
-        block_size = self.cache_config.block_size
-        seq_id = next(self.seq_counter)
-        eos_token_id = None
-        if self.tokenizer:
-            eos_token_id = self.tokenizer.get_lora_tokenizer(
-                lora_request).eos_token_id
-        else:
-            logger.warning("Use None for EOS token id because tokenizer is "
-                           "not initialized")
-        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                       eos_token_id, lora_request)
 
-        # Create a SequenceGroup based on SamplingParams or PoolingParams
-        if isinstance(params, SamplingParams):
-            seq_group = self._create_sequence_group_with_sampling(
-                request_id,
-                seq,
-                params,
-                arrival_time,
-                lora_request,
-                multi_modal_data,
-            )
-        elif isinstance(params, PoolingParams):
-            seq_group = self._create_sequence_group_with_pooling(
-                request_id,
-                seq,
-                params,
-                arrival_time,
-                lora_request,
-                multi_modal_data,
-            )
-        else:
-            raise ValueError(
-                "Either SamplingParams or PoolingParams must be provided.")
+        processed_inputs = self.process_model_inputs(request_id=request_id,
+                                                     inputs=inputs,
+                                                     lora_request=lora_request)
 
-        # Add the sequence group to the scheduler.
-        self.scheduler.add_seq_group(seq_group)
+        self._add_processed_request(
+            request_id=request_id,
+            processed_inputs=processed_inputs,
+            params=params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+        )
 
     def _create_sequence_group_with_sampling(
         self,
         request_id: str,
         seq: Sequence,
         sampling_params: SamplingParams,
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
@@ -440,8 +531,7 @@ class AphroditeEngine:
                                   seqs=[seq],
                                   arrival_time=arrival_time,
                                   sampling_params=sampling_params,
-                                  lora_request=lora_request,
-                                  multi_modal_data=multi_modal_data)
+                                  lora_request=lora_request)
 
         return seq_group
 
@@ -450,9 +540,8 @@ class AphroditeEngine:
         request_id: str,
         seq: Sequence,
         pooling_params: PoolingParams,
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
     ) -> SequenceGroup:
         """Creates a SequenceGroup with PoolingParams."""
         # Defensive copy of PoolingParams, which are used by the pooler
@@ -462,7 +551,6 @@ class AphroditeEngine:
                                   seqs=[seq],
                                   arrival_time=arrival_time,
                                   lora_request=lora_request,
-                                  multi_modal_data=multi_modal_data,
                                   pooling_params=pooling_params)
         return seq_group
 
@@ -515,7 +603,7 @@ class AphroditeEngine:
 
     def _process_model_outputs(
         self,
-        output: List[Union[SamplerOutput, PoolerOutput]],
+        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
         scheduled_seq_groups: List[ScheduledSequenceGroup],
         ignored_seq_groups: List[SequenceGroup],
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -530,7 +618,7 @@ class AphroditeEngine:
         # Organize outputs by [sequence group][step] instead of
         # [step][sequence group].
         output_by_sequence_group = create_output_by_sequence_group(
-            sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
+            output, num_seq_groups=len(scheduled_seq_groups))
 
         # Update the scheduled sequence groups with the model outputs.
         for scheduled_seq_group, outputs, seq_group_meta in zip(

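For orientation, a minimal usage sketch (not part of the commit) of the consolidated synchronous entry point shown above: a plain string is promoted to {"prompt": ...} by process_model_inputs, and pre-tokenized input now flows through the same path. `engine` is assumed to be an already-constructed AphroditeEngine; request IDs and prompts are placeholders.

# Illustrative sketch only; `engine` and the example prompts are hypothetical.
from aphrodite.common.sampling_params import SamplingParams

# Text prompt: promoted internally to {"prompt": ...} and tokenized.
engine.add_request(request_id="req-0",
                   inputs="Hello, my name is",
                   params=SamplingParams())

# Pre-tokenized prompt: uses the TokensPrompt schema, skipping tokenization.
engine.add_request(request_id="req-1",
                   inputs={"prompt_token_ids": [1, 2, 3, 4]},
                   params=SamplingParams())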
+ 69 - 74
aphrodite/engine/async_aphrodite.py

@@ -9,11 +9,11 @@ from loguru import logger
 from transformers import PreTrainedTokenizer
 
 from aphrodite.common.config import DecodingConfig, ModelConfig
+from aphrodite.common.inputs import LLMInputs, PromptInputs
 from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (ExecuteModelRequest, MultiModalData,
-                                       SamplerOutput)
+from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.args_tools import AsyncEngineArgs
 from aphrodite.executor.ray_utils import initialize_ray_cluster, ray
@@ -248,49 +248,54 @@ class _AsyncAphrodite(AphroditeEngine):
 
         return request_outputs
 
-    async def encode_request_async(
+    async def process_model_inputs_async(
         self,
-        request_id: str,  # pylint: disable=unused-argument
-        prompt: Optional[str],
-        prompt_token_ids: Optional[List[int]] = None,
+        request_id: str,
+        inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
-    ):
-        if prompt_token_ids is None:
-            assert prompt is not None
-            prompt_token_ids = await self.tokenizer.encode_async(
+    ) -> LLMInputs:
+        if isinstance(inputs, str):
+            inputs = {"prompt": inputs}
+
+        if "prompt_token_ids" not in inputs:
+            tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                                 "skip_tokenizer_init is True")
+
+            prompt_token_ids = await tokenizer.encode_async(
                 request_id=request_id,
-                prompt=prompt,
+                prompt=inputs["prompt"],
                 lora_request=lora_request)
-        return prompt_token_ids
+        else:
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=inputs.get("prompt"),
+                         multi_modal_data=inputs.get("multi_modal_data"))
 
     async def add_request_async(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
         if arrival_time is None:
             arrival_time = time.time()
-        prompt_token_ids = await self.encode_request_async(
+
+        processed_inputs = await self.process_model_inputs_async(
+            request_id=request_id, inputs=inputs, lora_request=lora_request)
+
+        self._add_processed_request(
             request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request)
-
-        return self.add_request(request_id,
-                                prompt=prompt,
-                                params=params,
-                                prompt_token_ids=prompt_token_ids,
-                                arrival_time=arrival_time,
-                                lora_request=lora_request,
-                                multi_modal_data=multi_modal_data)
+            processed_inputs=processed_inputs,
+            params=params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+        )
 
     async def check_health_async(self) -> None:
         self.model_executor.check_health()
@@ -376,8 +381,8 @@ class AsyncAphrodite:
             from aphrodite.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
-            from aphrodite.executor.multiproc_gpu_executor import (
-                MultiprocessingGPUExecutorAsync)
+            from aphrodite.executor.multiproc_gpu_executor import \
+                MultiprocessingGPUExecutorAsync
             executor_class = MultiprocessingGPUExecutorAsync
         else:
             from aphrodite.executor.gpu_executor import GPUExecutorAsync
@@ -529,22 +534,26 @@ class AsyncAphrodite:
     async def add_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncStream:
         if self.log_requests:
-            shortened_prompt = prompt
-            shortened_token_ids = prompt_token_ids
-            if self.max_log_len is not None:
+            if isinstance(inputs, str):
+                shortened_prompt = inputs
+                shortened_token_ids = None
+            else:
+                shortened_prompt = inputs.get("prompt")
+                shortened_token_ids = inputs.get("prompt_token_ids")
+
+            max_log_len = self.max_log_len
+            if max_log_len is not None:
                 if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:self.max_log_len]
+                    shortened_prompt = shortened_prompt[:max_log_len]
                 if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:self.
-                                                              max_log_len]
+                    shortened_token_ids = shortened_token_ids[:max_log_len]
+
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
                         f"params: {params}, "
@@ -565,39 +574,33 @@ class AsyncAphrodite:
             arrival_time = time.time()
 
         if self.engine_use_ray:
-            prompt_token_ids = await (
-                self.engine.encode_request_async.remote(  # type: ignore
+            processed_inputs = await self.engine.process_model_inputs_async \
+                .remote(  # type: ignore
                     request_id=request_id,
-                    prompt=prompt,
-                    prompt_token_ids=prompt_token_ids,
-                    lora_request=lora_request))
+                    inputs=inputs,
+                    lora_request=lora_request)
         else:
-            prompt_token_ids = await self.engine.encode_request_async(
+            processed_inputs = await self.engine.process_model_inputs_async(
                 request_id=request_id,
-                prompt=prompt,
-                prompt_token_ids=prompt_token_ids,
+                inputs=inputs,
                 lora_request=lora_request)
 
         stream = self._request_tracker.add_request(
             request_id,
-            prompt=prompt,
+            inputs=processed_inputs,
             params=params,
-            prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,
-            multi_modal_data=multi_modal_data,
         )
 
         return stream
 
     async def generate(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         sampling_params: SamplingParams,
         request_id: str,
-        prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
 
@@ -663,24 +666,20 @@ class AsyncAphrodite:
             >>> # Process and return the final output
             >>> ...
         """
-        async for output in self.process_request(
+        async for output in self._process_request(
                 request_id,
-                prompt,
+                inputs,
                 sampling_params,
-                prompt_token_ids,
-                lora_request,
-                multi_modal_data,
+                lora_request=lora_request,
         ):
-            yield output
+            yield AphroditeEngine.validate_output(output, RequestOutput)
 
     async def encode(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         pooling_params: PoolingParams,
         request_id: str,
-        prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[EmbeddingRequestOutput]:
         """Generate outputs for a request from an embedding model.
         This method is a coroutine. It adds the
@@ -735,24 +734,22 @@ class AsyncAphrodite:
             >>> # Process and return the final output
             >>> ...
         """
-        async for output in self.process_request(
+        async for output in self._process_request(
                 request_id,
-                prompt,
+                inputs,
                 pooling_params,
-                prompt_token_ids,
-                lora_request,
-                multi_modal_data,
+                lora_request=lora_request,
         ):
-            yield output
+            yield AphroditeEngine.validate_output(output,
+                                                  EmbeddingRequestOutput)
 
-    async def process_request(
+    async def _process_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
+        *,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or
         PoolingParams."""
@@ -760,12 +757,10 @@ class AsyncAphrodite:
 
         stream = await self.add_request(
             request_id,
-            prompt,
+            inputs,
             params,
-            prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,
-            multi_modal_data=multi_modal_data,
         )
 
         try:

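For the async side, a minimal sketch (not part of the commit) of streaming a request through the new `inputs` argument to AsyncAphrodite.generate. `async_engine` is a placeholder for an AsyncAphrodite instance constructed elsewhere; the output field access follows the existing RequestOutput schema, and error handling is omitted.

# Illustrative sketch only; `async_engine` is assumed to exist.
from aphrodite.common.sampling_params import SamplingParams

async def stream_one(async_engine) -> None:
    results = async_engine.generate(
        inputs={"prompt": "The capital of France is"},
        sampling_params=SamplingParams(max_tokens=16),
        request_id="demo-0",
    )
    async for request_output in results:
        # Each RequestOutput carries the cumulative text generated so far.
        print(request_output.outputs[0].text)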
+ 7 - 4
aphrodite/engine/output_processor/util.py

@@ -1,18 +1,21 @@
 from typing import List
+from typing import Sequence as GenericSequence
+from typing import Union
 
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupOutput
+from aphrodite.common.sequence import (PoolerOutput, SamplerOutput,
+                                       SequenceGroupOutput)
 
 
 def create_output_by_sequence_group(
-        sampler_outputs: List[SamplerOutput],
+        outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
         num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group: List[List[SamplerOutput]] = [
+    output_by_sequence_group: List[List[SequenceGroupOutput]] = [
         [] for _ in range(num_seq_groups)
     ]
-    for step in sampler_outputs:
+    for step in outputs:
         for i, sequence_group_output in enumerate(step):
             output_by_sequence_group[i].append(sequence_group_output)
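The transposition performed by create_output_by_sequence_group, shown with plain lists standing in for SequenceGroupOutput objects (illustrative sketch, not part of the commit):

steps = [
    ["g0_step0", "g1_step0"],   # step 0: one entry per sequence group
    ["g0_step1", "g1_step1"],   # step 1
]
num_seq_groups = 2
by_group = [[] for _ in range(num_seq_groups)]
for step in steps:
    for i, seq_group_output in enumerate(step):
        by_group[i].append(seq_group_output)
# Result is indexed [sequence group][step] instead of [step][sequence group].
assert by_group == [["g0_step0", "g0_step1"], ["g1_step0", "g1_step1"]]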