
feat: allow serving encoder-decoder models in the API server (#664)

* wip

* wip

* wip

* wip

* wip

* wip

* finish up

* update example
AlpinDale · 6 months ago · commit 62111fab17

+ 10 - 0
aphrodite/common/config.py

@@ -507,6 +507,16 @@ class ModelConfig:
             if t != "attention"
         ])
 
+    @property
+    def is_encoder_decoder_model(self) -> bool:
+        """Extract the HF encoder/decoder model flag."""
+        return getattr(self.hf_config, "is_encoder_decoder", False)
+
+    @property
+    def is_embedding_model(self) -> bool:
+        """Extract the embedding model flag."""
+        return self.embedding_mode
+
 
 class CacheConfig:
     """Configuration for the KV cache.

+ 1 - 1
aphrodite/common/sequence.py

@@ -12,7 +12,7 @@ import torch
 
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.inputs import is_valid_encoder_decoder_llm_inputs
+from aphrodite.inputs.parse import is_valid_encoder_decoder_llm_inputs
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 

+ 21 - 52
aphrodite/common/utils.py

@@ -17,8 +17,8 @@ from collections import defaultdict
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
-                    Union, overload)
+                    Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
+                    Type, TypeVar, Union, overload)
 from uuid import uuid4
 
 import numpy as np
@@ -27,12 +27,10 @@ import psutil
 import torch
 import torch.types
 from loguru import logger
-from typing_extensions import ParamSpec
+from typing_extensions import ParamSpec, TypeIs, assert_never
 
 from aphrodite import _custom_ops as ops
 from aphrodite.common.logger import enable_trace_function_call
-from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
-                              SingletonPromptInputs)
 
 # Exception strings for non-implemented encoder/decoder scenarios
 
@@ -811,6 +809,24 @@ def get_dtype_size(dtype: torch.dtype) -> int:
     return torch.tensor([], dtype=dtype).element_size()
 
 
+# `collections` helpers
+def is_list_of(
+    value: object,
+    typ: Type[T],
+    *,
+    check: Literal["first", "all"] = "first",
+) -> TypeIs[List[T]]:
+    if not isinstance(value, list):
+        return False
+
+    if check == "first":
+        return len(value) == 0 or isinstance(value[0], typ)
+    elif check == "all":
+        return all(isinstance(v, typ) for v in value)
+
+    assert_never(check)
+
+
 def merge_dicts(dict1: Dict[K, List[T]],
                 dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
     """Merge 2 dicts that have key -> List of items.
@@ -1075,50 +1091,3 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
-
-
-def is_encoder_decoder_model_config(model_config) -> bool:
-    '''
-    Extract the HF encoder/decoder model flag from the ModelConfig instance.
-    Return False if model_config is None.
-    '''
-    return model_config is not None and \
-                getattr(model_config.hf_config,
-                        "is_encoder_decoder",
-                        False)
-
-
-def is_embedding_model_config(model_config) -> bool:
-    '''
-    Extract the embedding model flag from the ModelConfig instance.
-    Return False if model_config is None.
-    '''
-    return model_config is not None and \
-                model_config.embedding_mode
-
-
-def build_explicit_enc_dec_prompt(
-    encoder_prompt: SingletonPromptInputs,
-    decoder_prompt: SingletonPromptInputs,
-) -> ExplicitEncoderDecoderPrompt:
-    return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
-                                        decoder_prompt=decoder_prompt)
-
-
-def zip_enc_dec_prompt_lists(
-    enc_prompt_list: List[SingletonPromptInputs],
-    dec_prompt_list: List[SingletonPromptInputs],
-) -> List[ExplicitEncoderDecoderPrompt]:
-    return [
-        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
-        for (encoder_prompt,
-             decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
-    ]
-
-
-def to_enc_dec_tuple_list(
-    enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
-) -> List[Tuple[PromptInputs, PromptInputs]]:
-    return [(enc_dec_prompt['encoder_prompt'],
-             enc_dec_prompt['decoder_prompt'])
-            for enc_dec_prompt in enc_dec_prompts]
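The new `is_list_of` helper is what `aphrodite/inputs/parse.py` and `aphrodite/multimodal/image.py` use below. A behavior sketch, following its definition directly:

```python
from aphrodite.common.utils import is_list_of

tokens = [1, 2, 3]
assert is_list_of(tokens, int)                 # check="first": only tokens[0]
assert is_list_of(tokens, int, check="all")    # every element is validated
assert not is_list_of(["a", 1], str, check="all")
assert is_list_of([], int)                     # empty lists pass either mode
```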

+ 2 - 3
aphrodite/endpoints/chat_utils.py

@@ -3,8 +3,7 @@ import tempfile
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
-                    cast, final)
+from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
 
 import requests
 from loguru import logger
@@ -58,7 +57,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                    CustomChatCompletionMessageParam]
 
 
-@final  # So that it should be compatible with Dict[str, str]
+# TODO: Make fields ReadOnly once mypy supports it
 class ConversationMessage(TypedDict):
     role: str
     content: str

+ 2 - 2
aphrodite/endpoints/llm.py

@@ -10,8 +10,8 @@ from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.utils import Counter, deprecate_kwargs
 from aphrodite.engine.aphrodite_engine import AphroditeEngine
 from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.inputs import (PromptInputs, TextPrompt, TokensPrompt,
-                              parse_and_batch_prompt)
+from aphrodite.inputs import PromptInputs, TextPrompt, TokensPrompt
+from aphrodite.inputs.parse import parse_and_batch_prompt
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.guided_decoding import (
     GuidedDecodingRequest, get_local_guided_decoding_logits_processor)

+ 5 - 3
aphrodite/endpoints/openai/logits_processors.py

@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
     return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
 
 
-def logit_bias_logits_processor(logit_bias: Dict[str,
-                                                 float], token_ids: List[int],
-                                logits: torch.Tensor) -> torch.Tensor:
+def logit_bias_logits_processor(
+    logit_bias: Dict[int, float],
+    token_ids: List[int],
+    logits: torch.Tensor,
+) -> torch.Tensor:
     for token_id, bias in logit_bias.items():
         logits[token_id] += bias
     return logits
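The annotation fix matters because the processor indexes the logits tensor with the bias keys, so they must be token IDs, not strings. A small sketch of the corrected behavior (importing the internal function directly, which may not be a stable API):

```python
import torch

from aphrodite.endpoints.openai.logits_processors import (
    logit_bias_logits_processor)

logits = torch.zeros(8)
out = logit_bias_logits_processor({3: 5.0, 5: -2.0}, [], logits)
assert out[3] == 5.0 and out[5] == -2.0  # biases added at token-id positions
```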

+ 1 - 1
aphrodite/endpoints/openai/serving_engine.py

@@ -27,7 +27,7 @@ from aphrodite.endpoints.openai.protocol import (ChatCompletionRequest,
                                                  TokenizeRequest)
 # yapf: enable
 from aphrodite.engine.protocol import AsyncEngineClient
-from aphrodite.inputs import parse_and_batch_prompt
+from aphrodite.inputs.parse import parse_and_batch_prompt
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.guided_decoding import (
     get_guided_decoding_logits_processor)

+ 149 - 168
aphrodite/engine/aphrodite_engine.py

@@ -7,6 +7,7 @@ from typing import Tuple, Type, TypeVar, Union
 
 from loguru import logger
 from transformers import PreTrainedTokenizer
+from typing_extensions import assert_never
 
 from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
                                      EngineConfig, LoadConfig, LoRAConfig,
@@ -22,8 +23,7 @@ from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
                                        ExecuteModelRequest, PoolerOutput,
                                        SamplerOutput, Sequence, SequenceGroup,
                                        SequenceGroupMetadata, SequenceStatus)
-from aphrodite.common.utils import (Counter, is_embedding_model_config,
-                                    is_encoder_decoder_model_config)
+from aphrodite.common.utils import Counter
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
                                       StatLoggerBase, Stats)
@@ -34,9 +34,11 @@ from aphrodite.engine.output_processor.util import (
     create_output_by_sequence_group)
 from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.executor.ray_utils import initialize_ray_cluster
-from aphrodite.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
-                              get_prompt_type)
+from aphrodite.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
+                              LLMInputs, PromptInputs, SingletonPromptInputs)
+from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
 from aphrodite.lora.request import LoRARequest
+from aphrodite.multimodal import MultiModalDataDict
 from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
                                             SchedulerOutputs)
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
@@ -67,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
 
 _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
 
+PromptComponents = Tuple[Optional[str], List[int],
+                         Optional[MultiModalDataDict]]
+DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
+                                Optional[MultiModalDataDict]]
+
 
 class AphroditeEngine:
     """An LLM engine that receives requests and generates texts.
@@ -472,7 +479,7 @@ class AphroditeEngine:
 
         return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
 
-    def _get_decoder_start_token_id(self, ) -> Optional[int]:
+    def _get_decoder_start_token_id(self) -> Optional[int]:
         '''
         Obtain the decoder start token id employed by an encoder/decoder
         model. Returns None for non-encoder/decoder models or if the
@@ -501,7 +508,7 @@ class AphroditeEngine:
     def _add_processed_request(
         self,
         request_id: str,
-        processed_inputs: LLMInputs,
+        processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
         params: Union[SamplingParams, PoolingParams],
         arrival_time: float,
         lora_request: Optional[LoRARequest],
@@ -561,11 +568,11 @@ class AphroditeEngine:
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
-    _LLMInputComponentsType = Tuple[str, List[int], ]
+    _LLMInputComponentsType = Tuple[str, List[int]]
 
     def _prepare_decoder_input_ids_for_generation(
         self,
-        decoder_input_ids: Optional[List[int]] = None,
+        decoder_input_ids: Optional[List[int]],
     ) -> List[int]:
         """
         Prepares `decoder_input_ids` for generation with encoder-decoder models.
@@ -580,14 +587,13 @@ class AphroditeEngine:
         * Processed token list
         """
 
-        decoder_start_token_id: Optional[int] = (
-            self._get_decoder_start_token_id())
+        decoder_start_token_id = self._get_decoder_start_token_id()
         assert decoder_start_token_id is not None
 
         if decoder_input_ids is None:
             # no decoder prompt input ->
             # use decoder_start_token_id as decoder_input_ids
-            (decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
+            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
 
         if (len(decoder_input_ids) == 0
                 or decoder_input_ids[0] != decoder_start_token_id):
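A simplified, standalone sketch of the rule this method enforces (token IDs are made up; in the real code the default prompt comes from `_get_default_enc_dec_decoder_prompt`, i.e. `[bos_token_id]`, and the start token from the HF generation config):

```python
decoder_start_token_id = 2  # made-up value
bos_token_id = 0            # made-up value

def prepare(decoder_input_ids):
    if decoder_input_ids is None:
        decoder_input_ids = [bos_token_id]      # default decoder prompt
    if (len(decoder_input_ids) == 0
            or decoder_input_ids[0] != decoder_start_token_id):
        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
    return decoder_input_ids

assert prepare(None) == [2, 0]
assert prepare([5, 7]) == [2, 5, 7]
assert prepare([2, 5]) == [2, 5]
```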
@@ -598,12 +604,11 @@ class AphroditeEngine:
     def _tokenize_prompt(
         self,
         prompt: str,
-        request_id: Optional[str] = None,
-        lora_request: Optional[str] = None,
+        request_id: str,
+        lora_request: Optional[LoRARequest],
     ) -> List[int]:
         '''
-        Wrapper around application of the model's
-        tokenizer.
+        Wrapper around application of the model's tokenizer.
         Arguments:
         * prompt
         * request_id
@@ -615,81 +620,68 @@ class AphroditeEngine:
         tokenizer = self.get_tokenizer_group("prompts must be None if "
                                              "skip_tokenizer_init is True")
 
-        prompt_token_ids = tokenizer.encode(request_id=request_id,
-                                            prompt=prompt,
-                                            lora_request=lora_request)
+        return tokenizer.encode(request_id=request_id,
+                                prompt=prompt,
+                                lora_request=lora_request)
 
-        return prompt_token_ids
-
-    def _extract_single_prompt_for_enc_dec_input(
+    def _extract_prompt_components(
         self,
-        inputs: Optional[PromptInputs],
-        request_id: Optional[str] = None,
-        ptype: Optional[str] = None,
-        is_encoder_prompt: bool = False,
-    ) -> Tuple[Optional[str], List[int]]:
+        inputs: SingletonPromptInputs,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> PromptComponents:
         '''
-        Only for encoder/decoder models:
-        Extract prompt & prompt_token_ids from any single
-        encoder or decoder input prompt. For encoder input prompts
-        in particular, also extract multi-modal data.
-        This function handles the following scenarios:
-        1. The user supplied a singleton encoder prompt
-          & the prompt/prompt-token-ids must be extracted.
-        2. The user supplied an explicit encoder/decoder
-          prompt & the prompt/prompt-token-ids must be
-          extracted from either the encoder and decoder prompts.
-        For decoder prompts in particular (scenario 2), special
-        processing is applied to the returned decoder token ids.
+        Extract the components of any single encoder or decoder input prompt.
         Arguments:
         * request_id
-        * ptype: str representation of the input prompt type.
-                 If `ptype` is `None`, assume that the prompt
-                 type is unknown and must be inferred. This is the
-                 case for ExplicitEncoderDecoder sub-prompts.
         * inputs: single encoder or decoder input prompt
-        * is_encoder_prompt: True if encoder input prompt.
-                             If False, decoder prompt tokens
-                             are preprocessed.
+        * lora_request: this is only valid for decoder prompts
         Returns:
         * prompt
         * prompt_token_ids
+        * multi_modal_data
         '''
-        prompt_token_ids = None
-        ptype = (get_prompt_type(inputs) if ptype is None else ptype)
 
-        if inputs is None:
-            prompt = None
-        elif ptype == 'str':
+        if isinstance(inputs, str):
             prompt = inputs
             prompt_token_ids = self._tokenize_prompt(
                 prompt,
                 request_id=request_id,
+                lora_request=lora_request,
             )
-        elif ptype == 'TokensPrompt':
-            prompt = None
-            prompt_token_ids = inputs['prompt_token_ids']
+            multi_modal_data = None
+        elif isinstance(inputs, dict):
+            if "prompt_token_ids" in inputs:
+                prompt = None
+                prompt_token_ids = inputs["prompt_token_ids"]
+            else:
+                # NOTE: This extra assignment is required to pass mypy
+                prompt = parsed_prompt = inputs["prompt"]
+                prompt_token_ids = self._tokenize_prompt(
+                    parsed_prompt,
+                    request_id=request_id,
+                    lora_request=lora_request,
+                )
+
+            multi_modal_data = inputs.get("multi_modal_data")
         else:
-            prompt = inputs['prompt']
-            prompt_token_ids = self._tokenize_prompt(
-                prompt,
-                request_id=request_id,
-            )
+            assert_never(inputs)
 
-        if not is_encoder_prompt:
-            # Apply special pre-processing to
-            # decoder prompts
-            prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
-                prompt_token_ids, ))
+        return prompt, prompt_token_ids, multi_modal_data
 
-        assert prompt_token_ids is not None
+    def _apply_prompt_adapter(
+        self,
+        prompt_token_ids: List[int],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+    ) -> List[int]:
+        if prompt_adapter_request:
+            prompt_token_ids = (
+                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+                + prompt_token_ids)
 
-        return (
-            prompt,
-            prompt_token_ids,
-        )
+        return prompt_token_ids
 
-    def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
+    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
         '''
         Specifically for encoder/decoder models:
         generate a default decoder prompt for when
@@ -718,18 +710,39 @@ class AphroditeEngine:
 
         bos_token_id = self._get_bos_token_id()
         assert bos_token_id is not None
-        prompt_token_ids: List[int] = [bos_token_id]
-        return prompt_token_ids
+        return [bos_token_id]
+
+    def _build_enc_dec_llm_inputs(
+        self,
+        encoder_comps: PromptComponents,
+        decoder_comps: DecoderPromptComponents,
+    ) -> EncoderDecoderLLMInputs:
+        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
+        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
+
+        if encoder_mm_data is not None or decoder_mm_data is not None:
+            raise ValueError("Multi-modal encoder-decoder models are "
+                             "not supported yet")
+
+        decoder_prompt_ids = (
+            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
+
+        return EncoderDecoderLLMInputs(
+            prompt_token_ids=decoder_prompt_ids,
+            prompt=decoder_prompt,
+            encoder_prompt_token_ids=encoder_prompt_ids,
+            encoder_prompt=encoder_prompt,
+        )
 
     def _process_encoder_decoder_prompt(
         self,
         inputs: PromptInputs,
-        request_id: Optional[str] = None,
-    ) -> LLMInputs:
+        request_id: str,
+    ) -> EncoderDecoderLLMInputs:
         '''
         For encoder/decoder models only:
-        Process an input prompt
-        into an `LLMInputs` instance.
+        Process an input prompt into an
+        :class:`EncoderDecoderLLMInputs` instance.
         There are two types of input prompts:
         singleton prompts which carry only the
         encoder prompt, and explicit encoder/decoder
@@ -750,131 +763,98 @@ class AphroditeEngine:
         * inputs: an input prompt
         * request_id
         Returns:
-        * `LLMInputs` instance
+        * :class:`EncoderDecoderLLMInputs` instance
         '''
 
-        ptype = get_prompt_type(inputs)
-
-        # Obtain encoder and decoder prompt tokens. Note
-        # that, no matter what, the decoder
-        # prompt type is unknown.
-        if ptype == "ExplicitEncoderDecoder":
-            # If input is explicit encoder/decoder prompt,
-            # then it remains to be determined what type
-            # of encoder prompt we have
-            extracted_encoder_prompt = inputs.get('encoder_prompt')
-            encoder_ptype = None
-            # Extract decoder prompt from explicit
-            # encoder/decoder prompt
-            extracted_decoder_prompt = inputs.get('decoder_prompt')
+        encoder_comps: PromptComponents
+        decoder_comps: DecoderPromptComponents
+
+        if is_explicit_encoder_decoder_prompt(inputs):
+            encoder_comps = self._extract_prompt_components(
+                inputs["encoder_prompt"],
+                request_id=request_id,
+            )
+
+            if (decoder_input := inputs["decoder_prompt"]) is None:
+                decoder_comps = None, None, None
+            else:
+                decoder_comps = self._extract_prompt_components(
+                    decoder_input,
+                    request_id=request_id,
+                )
         else:
-            # If input is singleton encoder prompt, then
-            # we know the encoder prompt type
-            extracted_encoder_prompt = inputs
-            encoder_ptype = ptype
-            # Decoder prompt is always unknown if
-            # encoder/decoder prompt is not explicit
-            extracted_decoder_prompt = None
-
-        # Invoke helper function to obtain encoder
-        # prompt and prompt token ids, either from
-        # singleton encoder prompt or from the
-        # encoder sub-prompt of an explicit
-        # encoder/decode scenario 2), special
-        # processing is applied to the returned decoder token ids
-        (
-            encoder_prompt,
-            encoder_prompt_token_ids,
-        ) = self._extract_single_prompt_for_enc_dec_input(
-            extracted_encoder_prompt,
-            request_id=request_id,
-            ptype=encoder_ptype,
-            is_encoder_prompt=True,
-        )
+            encoder_comps = self._extract_prompt_components(
+                inputs,
+                request_id=request_id,
+            )
 
-        # Invoke helper method to obtain
-        # decoder prompt and prompt token ids.
-        #
-        # The helper method will detect the decoder
-        # prompt type.
-        #
-        # Helper method will also apply special
-        # preprocessing unique to decoder prompts.
-        (
-            decoder_prompt,
-            decoder_prompt_token_ids,
-        ) = self._extract_single_prompt_for_enc_dec_input(
-            extracted_decoder_prompt,
-            request_id=request_id,
-            ptype=None,
-            is_encoder_prompt=False,
-        )
+            decoder_comps = None, None, None
 
-        return LLMInputs(
-            prompt_token_ids=decoder_prompt_token_ids,
-            prompt=decoder_prompt,
-            encoder_prompt_token_ids=encoder_prompt_token_ids,
-            encoder_prompt=encoder_prompt,
-        )
+        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
+
+    def _build_decoder_only_llm_inputs(
+        self,
+        prompt_comps: PromptComponents,
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+    ) -> LLMInputs:
+        prompt, prompt_token_ids, multi_modal_data = prompt_comps
+
+        prompt_token_ids = self._apply_prompt_adapter(
+            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=prompt,
+                         multi_modal_data=multi_modal_data)
 
     def _process_decoder_only_prompt(
         self,
-        inputs: PromptInputs,
+        inputs: SingletonPromptInputs,
+        request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        request_id: Optional[str] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
         '''
         For decoder-only models:
-        Process an input prompt
-        into an `LLMInputs` instance.
+        Process an input prompt into an :class:`LLMInputs` instance.
         Arguments:
         * inputs: input prompt
-        * lora_request
         * request_id
+        * lora_request
         * prompt_adapter_request
         Returns:
-        * `LLMInputs` instance
+        * :class:`LLMInputs` instance
         '''
-        if isinstance(inputs, str):
-            inputs = {"prompt": inputs}
-        prompt = inputs.get("prompt")
-
-        if "prompt_token_ids" not in inputs:
-            prompt_token_ids = self._tokenize_prompt(
-                prompt,
-                request_id=request_id,
-                lora_request=lora_request,
-            )
-        else:
-            prompt_token_ids = inputs["prompt_token_ids"]
 
-        if prompt_adapter_request:
-            prompt_token_ids = (
-                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
-                + prompt_token_ids)
+        prompt_comps = self._extract_prompt_components(
+            inputs,
+            request_id=request_id,
+            lora_request=lora_request,
+        )
 
-        return LLMInputs(prompt_token_ids=prompt_token_ids,
-                         prompt=prompt,
-                         multi_modal_data=inputs.get("multi_modal_data"))
+        return self._build_decoder_only_llm_inputs(
+            prompt_comps,
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
     def process_model_inputs(
         self,
-        request_id: str,
         inputs: PromptInputs,
+        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> LLMInputs:
+    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
 
         if self.is_encoder_decoder_model():
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
-
             model_inputs = self._process_encoder_decoder_prompt(
                 inputs,
                 request_id=request_id,
             )
         else:
+            if is_explicit_encoder_decoder_prompt(inputs):
+                raise ValueError("Cannot pass encoder-decoder prompt "
+                                 "to decoder-only models")
             # Decoder-only operation
             model_inputs = self._process_decoder_only_prompt(
                 inputs,
@@ -945,10 +925,11 @@ class AphroditeEngine:
             arrival_time = time.time()
 
         processed_inputs = self.process_model_inputs(
+            inputs,
             request_id=request_id,
-            inputs=inputs,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
         self._add_processed_request(
             request_id=request_id,
@@ -1450,10 +1431,10 @@ class AphroditeEngine:
         self.model_executor.check_health()
 
     def is_encoder_decoder_model(self):
-        return is_encoder_decoder_model_config(self.model_config)
+        return self.model_config.is_encoder_decoder_model
 
     def is_embedding_model(self):
-        return is_embedding_model_config(self.model_config)
+        return self.model_config.is_embedding_model
 
 
 setup_logger()
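Taken together, the refactor splits prompt processing into extract/build stages and routes explicit encoder/decoder prompts separately from singleton prompts. A hedged end-to-end sketch of the new flow (the `engine` instance and its construction are assumed, not shown in this commit):

```python
from aphrodite.inputs import build_explicit_enc_dec_prompt

enc_dec_prompt = build_explicit_enc_dec_prompt(
    encoder_prompt="Summarize: The quick brown fox jumps over the lazy dog.",
    decoder_prompt=None,  # None -> engine substitutes the decoder start token
)

# Assuming `engine` is an AphroditeEngine serving an encoder/decoder model:
# processed = engine.process_model_inputs(enc_dec_prompt, request_id="0")
# `processed` would be an EncoderDecoderLLMInputs with both encoder and
# decoder token IDs populated.
```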

+ 131 - 24
aphrodite/engine/async_aphrodite.py

@@ -7,6 +7,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Optional,
 
 from loguru import logger
 from transformers import PreTrainedTokenizer
+from typing_extensions import assert_never
 
 from aphrodite.common.config import (DecodingConfig, EngineConfig, LoRAConfig,
                                      ModelConfig, ParallelConfig,
@@ -15,13 +16,17 @@ from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
-from aphrodite.engine.aphrodite_engine import AphroditeEngine
+from aphrodite.engine.aphrodite_engine import (AphroditeEngine,
+                                               DecoderPromptComponents,
+                                               PromptComponents)
 from aphrodite.engine.args_tools import AsyncEngineArgs
 from aphrodite.engine.async_timeout import asyncio_timeout
 from aphrodite.engine.metrics import StatLoggerBase
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.ray_utils import initialize_ray_cluster, ray
-from aphrodite.inputs import LLMInputs, PromptInputs
+from aphrodite.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
+                              SingletonPromptInputs)
+from aphrodite.inputs.parse import is_explicit_encoder_decoder_prompt
 from aphrodite.lora.request import LoRARequest
 from aphrodite.processing.scheduler import SchedulerOutputs
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
@@ -290,38 +295,138 @@ class _AsyncAphrodite(AphroditeEngine):
         """Stop the remote worker execution loop."""
         await self.model_executor.stop_remote_worker_execution_loop_async()
 
-    async def process_model_inputs_async(
+    async def _tokenize_prompt_async(
+        self,
+        prompt: str,
+        request_id: str,
+        lora_request: Optional[LoRARequest],
+    ) -> List[int]:
+        """Async version of :meth:`_tokenize_prompt`."""
+        tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                             "skip_tokenizer_init is True")
+
+        return await tokenizer.encode_async(request_id=request_id,
+                                            prompt=prompt,
+                                            lora_request=lora_request)
+
+    async def _extract_prompt_components_async(
         self,
+        inputs: SingletonPromptInputs,
         request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> PromptComponents:
+        """Async version of :meth:`_extract_prompt_components`."""
+        if isinstance(inputs, str):
+            prompt = inputs
+            prompt_token_ids = await self._tokenize_prompt_async(
+                prompt,
+                request_id=request_id,
+                lora_request=lora_request,
+            )
+            multi_modal_data = None
+        elif isinstance(inputs, dict):
+            if "prompt_token_ids" in inputs:
+                prompt = None
+                prompt_token_ids = inputs["prompt_token_ids"]
+            else:
+                # NOTE: This extra assignment is required to pass mypy
+                prompt = parsed_prompt = inputs["prompt"]
+                prompt_token_ids = await self._tokenize_prompt_async(
+                    parsed_prompt,
+                    request_id=request_id,
+                    lora_request=lora_request,
+                )
+
+            multi_modal_data = inputs.get("multi_modal_data")
+        else:
+            assert_never(inputs)
+
+        return prompt, prompt_token_ids, multi_modal_data
+
+    async def _process_encoder_decoder_prompt_async(
+        self,
         inputs: PromptInputs,
+        request_id: str,
+    ) -> EncoderDecoderLLMInputs:
+        """Async version of :meth:`_process_encoder_decoder_prompt`."""
+        encoder_comps: PromptComponents
+        decoder_comps: DecoderPromptComponents
+
+        if is_explicit_encoder_decoder_prompt(inputs):
+            encoder_task = self._extract_prompt_components_async(
+                inputs["encoder_prompt"],
+                request_id=request_id,
+            )
+
+            if (decoder_input := inputs["decoder_prompt"]) is None:
+                encoder_comps = await encoder_task
+                decoder_comps = None, None, None
+            else:
+                decoder_task = self._extract_prompt_components_async(
+                    decoder_input,
+                    request_id=request_id,
+                )
+
+                encoder_comps, decoder_comps = await asyncio.gather(
+                    encoder_task, decoder_task)
+        else:
+            encoder_comps = await self._extract_prompt_components_async(
+                inputs,
+                request_id=request_id,
+            )
+
+            decoder_comps = None, None, None
+
+        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
+
+    async def _process_decoder_only_prompt_async(
+        self,
+        inputs: SingletonPromptInputs,
+        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
-        if isinstance(inputs, str):
-            inputs = {"prompt": inputs}
+        """Async version of :meth:`_process_decoder_only_prompt`."""
+        prompt_comps = await self._extract_prompt_components_async(
+            inputs,
+            request_id=request_id,
+            lora_request=lora_request,
+        )
 
-        if "prompt_token_ids" not in inputs:
-            tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                                 "skip_tokenizer_init is True")
+        return self._build_decoder_only_llm_inputs(
+            prompt_comps,
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
-            prompt_token_ids = await tokenizer.encode_async(
+    async def process_model_inputs_async(
+        self,
+        inputs: PromptInputs,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
+        """Async version of :meth:`process_model_inputs`."""
+        if self.is_encoder_decoder_model():
+            # Encoder-decoder model requires special mapping of
+            # input prompts to encoder & decoder
+            model_inputs = await self._process_encoder_decoder_prompt_async(
+                inputs,
                 request_id=request_id,
-                prompt=inputs["prompt"],
-                lora_request=lora_request)
+            )
         else:
-            prompt_token_ids = inputs["prompt_token_ids"]
+            if is_explicit_encoder_decoder_prompt(inputs):
+                raise ValueError("Cannot pass encoder-decoder prompt "
+                                 "to decoder-only models")
 
-        if prompt_adapter_request:
-            prompt_token_ids = [
-                0
-            ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
-                prompt_token_ids
-
-        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
-                               prompt=inputs.get("prompt"),
-                               multi_modal_data=inputs.get("multi_modal_data"))
+            # Decoder-only operation
+            model_inputs = await self._process_decoder_only_prompt_async(
+                inputs,
+                request_id=request_id,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
 
-        return self.input_processor(llm_inputs)
+        return self.input_processor(model_inputs)
 
     async def add_request_async(
         self,
@@ -332,6 +437,7 @@ class _AsyncAphrodite(AphroditeEngine):
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
+        """Async version of :meth:`add_request`."""
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
@@ -339,10 +445,11 @@ class _AsyncAphrodite(AphroditeEngine):
             arrival_time = time.time()
 
         processed_inputs = await self.process_model_inputs_async(
+            inputs,
             request_id=request_id,
-            inputs=inputs,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+        )
 
         self._add_processed_request(
             request_id=request_id,
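The async path mirrors the sync methods but tokenizes the encoder and decoder prompts concurrently. A minimal illustration of that `asyncio.gather` pattern (the `tokenize` coroutine is a stand-in, not the engine's tokenizer):

```python
import asyncio
from typing import List

async def tokenize(text: str) -> List[int]:
    await asyncio.sleep(0)  # stand-in for tokenizer.encode_async
    return list(text.encode())

async def main() -> None:
    # Both prompts are tokenized concurrently, as in
    # _process_encoder_decoder_prompt_async above.
    enc_ids, dec_ids = await asyncio.gather(
        tokenize("encoder prompt"),
        tokenize("decoder prompt"),
    )
    print(len(enc_ids), len(dec_ids))

asyncio.run(main())
```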

+ 10 - 11
aphrodite/inputs/__init__.py

@@ -1,7 +1,7 @@
-from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
-                   ParsedTokens, PromptInputs, SingletonPromptInputs,
-                   TextPrompt, TokensPrompt, get_prompt_type,
-                   is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
+from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
+                   LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
+                   TokensPrompt, build_explicit_enc_dec_prompt,
+                   to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from .registry import InputContext, InputRegistry
 
 INPUT_REGISTRY = InputRegistry()
@@ -14,18 +14,17 @@ See also:
 """
 
 __all__ = [
-    "ParsedText",
-    "ParsedTokens",
-    "parse_and_batch_prompt",
     "TextPrompt",
     "TokensPrompt",
     "PromptInputs",
+    "SingletonPromptInputs",
+    "ExplicitEncoderDecoderPrompt",
     "LLMInputs",
+    "EncoderDecoderLLMInputs",
+    "build_explicit_enc_dec_prompt",
+    "to_enc_dec_tuple_list",
+    "zip_enc_dec_prompts",
     "INPUT_REGISTRY",
     "InputContext",
     "InputRegistry",
-    "get_prompt_type",
-    "is_valid_encoder_decoder_llm_inputs",
-    "ExplicitEncoderDecoderPrompt",
-    "SingletonPromptInputs",
 ]
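After this reshuffle, parsing helpers come from the `parse` submodule while prompt-construction helpers stay at the package root:

```python
from aphrodite.inputs import (build_explicit_enc_dec_prompt,
                              to_enc_dec_tuple_list, zip_enc_dec_prompts)
from aphrodite.inputs.parse import (is_explicit_encoder_decoder_prompt,
                                    parse_and_batch_prompt)
```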

+ 65 - 123
aphrodite/inputs/data.py

@@ -1,70 +1,12 @@
-from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
-                    TypedDict, Union, cast, overload)
+from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
+                    Union)
 
-from typing_extensions import NotRequired
+from typing_extensions import NotRequired, TypedDict, TypeVar
 
 if TYPE_CHECKING:
     from aphrodite.multimodal import MultiModalDataDict
 
 
-class ParsedText(TypedDict):
-    content: str
-    is_tokens: Literal[False]
-
-
-class ParsedTokens(TypedDict):
-    content: List[int]
-    is_tokens: Literal[True]
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
-    ...
-
-
-@overload
-def parse_and_batch_prompt(
-        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
-    ...
-
-
-def parse_and_batch_prompt(
-    prompt: Union[str, List[str], List[int], List[List[int]]],
-) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
-    if isinstance(prompt, str):
-        # case 1: a string
-        return [ParsedText(content=prompt, is_tokens=False)]
-
-    if isinstance(prompt, list):
-        if len(prompt) == 0:
-            raise ValueError("please provide at least one prompt")
-
-        if isinstance(prompt[0], str):
-            # case 2: array of strings
-            return [
-                ParsedText(content=elem, is_tokens=False)
-                for elem in cast(List[str], prompt)
-            ]
-        if isinstance(prompt[0], int):
-            # case 3: array of tokens
-            elem = cast(List[int], prompt)
-            return [ParsedTokens(content=elem, is_tokens=True)]
-        if isinstance(prompt[0], list):
-            if len(prompt[0]) == 0:
-                raise ValueError("please provide at least one prompt")
-
-            if isinstance(prompt[0][0], int):
-                # case 4: array of token arrays
-                return [
-                    ParsedTokens(content=elem, is_tokens=True)
-                    for elem in cast(List[List[int]], prompt)
-                ]
-
-    raise ValueError("prompt must be a string, array of strings, "
-                     "array of tokens, or array of token arrays")
-
-
 class TextPrompt(TypedDict):
     """Schema for a text prompt."""
 
@@ -109,7 +51,18 @@ where the decoder-prompt is not specified explicitly, or
 more than one prompt, i.e. ExplicitEncoderDecoderPrompt
 """
 
-class ExplicitEncoderDecoderPrompt(TypedDict):
+_T1_co = TypeVar("_T1_co",
+                 bound=SingletonPromptInputs,
+                 default=SingletonPromptInputs,
+                 covariant=True)
+_T2_co = TypeVar("_T2_co",
+                 bound=SingletonPromptInputs,
+                 default=SingletonPromptInputs,
+                 covariant=True)
+
+
+# TODO: Make fields ReadOnly once mypy supports it
+class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
     """Represents an encoder/decoder model input prompt,
     comprising an explicit encoder prompt and a 
     decoder prompt.
@@ -125,9 +78,9 @@ class ExplicitEncoderDecoderPrompt(TypedDict):
     must be SingletonPromptInputs instances.
     """
 
-    encoder_prompt: SingletonPromptInputs
+    encoder_prompt: _T1_co
 
-    decoder_prompt: SingletonPromptInputs
+    decoder_prompt: Optional[_T2_co]
 
 
 PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
@@ -141,56 +94,12 @@ both decoder-only and encoder/decoder input types:
 """
 
 
-def _has_required_keys(
-    d: dict,
-    required_keys: set,
-) -> bool:
-    return required_keys.issubset(d.keys())
-
-
-def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
-    """
-    Get the type-name of the prompt argument instance, given that
-    isinstance() cannot apply to TypedDict subclasses directly.
-    If the prompt is None, return 'None' as the type name.
-    Arguments:
-    * prompt: LLM input prompt or None
-    Returns:
-    * String representation of prompt type
-    """
-
-    if prompt is None:
-        return 'None'
-
-    required_keys_dict = {
-        'TextPrompt': {'prompt'},
-        'TokensPrompt': {'prompt_token_ids'},
-        'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
-    }
-
-    if isinstance(prompt, dict):
-        for (ptype, required_keys) in required_keys_dict.items():
-            # Ignore type checking in the conditional below because type
-            # checker does not understand that is_dict(prompt) narrows
-            # down the possible types
-            if _has_required_keys(
-                    prompt,  # type: ignore
-                    required_keys):
-                return ptype
-
-        raise ValueError(f"Invalid prompt {prompt}, valid types are "
-                         "required_keys_dict={required_keys_dict}")
-
-    if isinstance(prompt, str):
-        return "str"
-
-    raise ValueError(f"Invalid prompt {prompt}")
-
-
 class LLMInputs(TypedDict):
     """
     The inputs in :class:`~aphrodite.AphroditeEngine` before they are
     passed to the model executor.
+
+    This specifies the data required for decoder-only models.
     """
 
     prompt_token_ids: List[int]
@@ -201,7 +110,20 @@ class LLMInputs(TypedDict):
     The original prompt text corresponding to the token IDs, if available.
     """
 
-    encoder_prompt_token_ids: NotRequired[List[int]]
+    multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
+    """
+    Optional multi-modal data to pass to the model,
+    if the model supports it.
+    """
+
+
+class EncoderDecoderLLMInputs(LLMInputs):
+    """
+    The inputs in :class:`~aphrodite.AphroditeEngine` before they are
+    passed to the model executor.
+    This specifies the required data for encoder-decoder models.
+    """
+    encoder_prompt_token_ids: List[int]
     """The token IDs of the encoder prompt."""
 
     encoder_prompt: NotRequired[Optional[str]]
@@ -210,20 +132,40 @@ class LLMInputs(TypedDict):
     available.
     """
 
-    multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
-    """
-    Optional multi-modal data to pass to the model,
-    if the model supports it.
-    """
+
+_T1 = TypeVar("_T1",
+              bound=SingletonPromptInputs,
+              default=SingletonPromptInputs)
+_T2 = TypeVar("_T2",
+              bound=SingletonPromptInputs,
+              default=SingletonPromptInputs)
 
 
-def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
+def build_explicit_enc_dec_prompt(
+    encoder_prompt: _T1,
+    decoder_prompt: Optional[_T2],
+) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
+    return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
+                                        decoder_prompt=decoder_prompt)
+
+
+def zip_enc_dec_prompts(
+    enc_prompts: Iterable[_T1],
+    dec_prompts: Iterable[Optional[_T2]],
+) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
     """
-    Return True if the LLMInputs instance has the correct configuration
-    for encoder/decoder.
+    Zip encoder and decoder prompts together into a list of
+    :class:`ExplicitEncoderDecoderPrompt` instances.
     """
+    return [
+        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
+        for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
+    ]
+
 
-    # True if encoder prompt token ids field exists &
-    # is not None
-    return ('encoder_prompt_token_ids' in inputs
-            and inputs['encoder_prompt_token_ids'] is not None)
+def to_enc_dec_tuple_list(
+    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
+) -> List[Tuple[_T1, Optional[_T2]]]:
+    return [(enc_dec_prompt["encoder_prompt"],
+             enc_dec_prompt["decoder_prompt"])
+            for enc_dec_prompt in enc_dec_prompts]
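A round-trip sketch using the generic helpers above; note that the decoder side may now be `None`:

```python
from aphrodite.inputs import to_enc_dec_tuple_list, zip_enc_dec_prompts

pairs = zip_enc_dec_prompts(
    ["An encoder prompt", "Another encoder prompt"],
    ["A decoder prompt", None],
)
assert to_enc_dec_tuple_list(pairs)[1] == ("Another encoder prompt", None)
```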

+ 75 - 0
aphrodite/inputs/parse.py

@@ -0,0 +1,75 @@
+from typing import List, Literal, Sequence, TypedDict, Union, overload
+
+from typing_extensions import TypeIs
+
+from aphrodite.common.utils import is_list_of
+
+from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
+                   LLMInputs, PromptInputs)
+
+
+class ParsedText(TypedDict):
+    content: str
+    is_tokens: Literal[False]
+
+
+class ParsedTokens(TypedDict):
+    content: List[int]
+    is_tokens: Literal[True]
+
+
+@overload
+def parse_and_batch_prompt(
+        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
+    ...
+
+
+@overload
+def parse_and_batch_prompt(
+        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
+    ...
+
+
+def parse_and_batch_prompt(
+    prompt: Union[str, List[str], List[int], List[List[int]]],
+) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
+    if isinstance(prompt, str):
+        # case 1: a string
+        return [ParsedText(content=prompt, is_tokens=False)]
+
+    if isinstance(prompt, list):
+        if len(prompt) == 0:
+            raise ValueError("please provide at least one prompt")
+
+        if is_list_of(prompt, str):
+            # case 2: array of strings
+            return [
+                ParsedText(content=elem, is_tokens=False) for elem in prompt
+            ]
+        if is_list_of(prompt, int):
+            # case 3: array of tokens
+            return [ParsedTokens(content=prompt, is_tokens=True)]
+        if is_list_of(prompt, list):
+            if len(prompt[0]) == 0:
+                raise ValueError("please provide at least one prompt")
+
+            if is_list_of(prompt[0], int):
+                # case 4: array of token arrays
+                return [
+                    ParsedTokens(content=elem, is_tokens=True)
+                    for elem in prompt
+                ]
+
+    raise ValueError("prompt must be a string, array of strings, "
+                     "array of tokens, or array of token arrays")
+
+
+def is_explicit_encoder_decoder_prompt(
+        inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
+    return isinstance(inputs, dict) and "encoder_prompt" in inputs
+
+
+def is_valid_encoder_decoder_llm_inputs(
+    inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
+) -> TypeIs[EncoderDecoderLLMInputs]:
+    return "encoder_prompt_token_ids" in inputs

+ 11 - 11
aphrodite/modeling/models/interfaces.py

@@ -2,7 +2,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
                     Union, overload, runtime_checkable)
 
 from loguru import logger
-from typing_extensions import TypeGuard
+from typing_extensions import TypeIs
 
 from aphrodite.common.config import (LoRAConfig, MultiModalConfig,
                                      SchedulerConfig)
@@ -36,18 +36,18 @@ class _SupportsVisionType(Protocol):
 
 
 @overload
-def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
+def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]:
     ...
 
 
 @overload
-def supports_vision(model: object) -> TypeGuard[SupportsVision]:
+def supports_vision(model: object) -> TypeIs[SupportsVision]:
     ...
 
 
 def supports_vision(
     model: Union[Type[object], object],
-) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
+) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]:
     if isinstance(model, type):
         return isinstance(model, _SupportsVisionType)
 
@@ -93,18 +93,18 @@ class _SupportsLoRAType(Protocol):
 
 
 @overload
-def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
+def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
     ...
 
 
 @overload
-def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
+def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
     ...
 
 
 def supports_lora(
     model: Union[Type[object], object],
-) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
     result = _supports_lora(model)
 
     if not result:
@@ -134,7 +134,7 @@ def supports_lora(
 
 def _supports_lora(
     model: Union[Type[object], object],
-) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
     if isinstance(model, type):
         return isinstance(model, _SupportsLoRAType)
 
@@ -169,18 +169,18 @@ class _HasInnerStateType(Protocol):
 
 
 @overload
-def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
+def has_inner_state(model: object) -> TypeIs[HasInnerState]:
     ...
 
 
 @overload
-def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
+def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
     ...
 
 
 def has_inner_state(
     model: Union[Type[object], object]
-) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
+) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
     if isinstance(model, type):
         return isinstance(model, _HasInnerStateType)
 

+ 4 - 2
aphrodite/multimodal/image.py

@@ -7,6 +7,7 @@ from PIL import Image
 from transformers import PreTrainedTokenizerBase
 
 from aphrodite.common.config import ModelConfig
+from aphrodite.common.utils import is_list_of
 from aphrodite.inputs.registry import InputContext
 from aphrodite.transformers_utils.image_processor import get_image_processor
 from aphrodite.transformers_utils.tokenizer import get_tokenizer
@@ -109,7 +110,8 @@ class ImagePlugin(MultiModalPlugin):
     def _default_input_mapper(self, ctx: InputContext,
                               data: object) -> MultiModalInputs:
         model_config = ctx.model_config
-        if isinstance(data, (Image.Image, list)):
+
+        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
             if image_processor is None:
                 raise RuntimeError("No HuggingFace processor is available"
@@ -123,7 +125,7 @@ class ImagePlugin(MultiModalPlugin):
                 raise
 
             return MultiModalInputs(batch_data)
-        elif isinstance(data, torch.Tensor):
+        elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
             raise NotImplementedError("Embeddings input is not supported yet")
 
         raise TypeError(f"Invalid image type: {type(data)}")

+ 2 - 4
aphrodite/task_handler/worker.py

@@ -13,8 +13,6 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      ParallelConfig, PromptAdapterConfig,
                                      SchedulerConfig, SpeculativeConfig)
 from aphrodite.common.sequence import ExecuteModelRequest
-from aphrodite.common.utils import (is_embedding_model_config,
-                                    is_encoder_decoder_model_config)
 from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -121,10 +119,10 @@ class Worker(LocalOrDistributedWorkerBase):
         self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
 
     def _is_encoder_decoder_model(self):
-        return is_encoder_decoder_model_config(self.model_config)
+        return self.model_config.is_encoder_decoder_model
 
     def _is_embedding_model(self):
-        return is_embedding_model_config(self.model_config)
+        return self.model_config.is_embedding_model
 
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":

+ 3 - 4
examples/offline_inference/encoder_decoder_inference.py

@@ -1,9 +1,8 @@
 """Prompting encoder-decoder models, specifically the BART model."""
 
 from aphrodite import LLM, SamplingParams
-from aphrodite.common.utils import zip_enc_dec_prompt_lists
 from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                              TokensPrompt)
+                              TokensPrompt, zip_enc_dec_prompts)
 
 dtype = "float"
 
@@ -59,9 +58,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
 )
 
 # - Finally, here's a useful helper function for zipping encoder and
-#   decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
+#   decoder prompts together into a list of ExplicitEncoderDecoderPrompt
 #   instances
-zipped_prompt_list = zip_enc_dec_prompt_lists(
+zipped_prompt_list = zip_enc_dec_prompts(
     ['An encoder prompt', 'Another encoder prompt'],
     ['A decoder prompt', 'Another decoder prompt'])
 

+ 1 - 1
requirements-common.txt

@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0
 lm-format-enforcer == 0.10.3
 outlines >= 0.0.43, < 0.1
-typing_extensions
+typing_extensions >= 4.10
 filelock >= 3.10.4
 pyzmq
 scipy # for quip

+ 1 - 1
requirements-lint.txt

@@ -8,7 +8,7 @@ isort==5.13.2
 clang-format==18.1.5
 
 # type checking
-mypy==1.9.0
+mypy==1.11.1
 types-PyYAML
 types-requests
 types-setuptools