Browse Source

chore: remove `image_input_type` from VLM config

AlpinDale, 7 months ago
commit 3a0fdf7b9b

+ 1 - 33
aphrodite/common/config.py

@@ -1329,28 +1329,12 @@ class LoRAConfig:
             raise ValueError("LoRA is not supported with chunked prefill yet.")
 
 
+# TODO: Replace with MultiModalConfig.
 @dataclass
 class VisionLanguageConfig:
     """Configs the input data format and how models should run for
     vision language models."""
 
-    class ImageInputType(enum.Enum):
-        """Image input type into the vision language model.
-
-        An image roughly goes through the following transformation:
-        Raw image --> pixel values --> image features --> image embeddings.
-
-        The difference between different image input types is where the
-        image encoder (pixel values --> image features) is run.
-        Different image input types also correspond to different tensor shapes.
-
-        For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
-        IMAGE_FEATURES: (1, 576, 1024).
-        """
-        PIXEL_VALUES = enum.auto()
-        IMAGE_FEATURES = enum.auto()
-
-    image_input_type: ImageInputType
     # The input id corresponding to image token.
     image_token_id: int
     # Used for running `run_prefill_max_token`.
@@ -1358,20 +1342,6 @@ class VisionLanguageConfig:
     # worst case scenario (biggest supported resolution).
     image_input_shape: tuple
     image_feature_size: int
-    # The image processor to load from HuggingFace
-    image_processor: Optional[str]
-    image_processor_revision: Optional[str]
-
-    @classmethod
-    def get_image_input_enum_type(
-            cls, value: str) -> "VisionLanguageConfig.ImageInputType":
-        """Get the image input type from a string."""
-        try:
-            return cls.ImageInputType[value.upper()]
-        except KeyError as e:
-            raise ValueError(f"{value} is not a valid choice. "
-                             f"Expecting to choose from "
-                             f"{[x.name for x in cls.ImageInputType]}.") from e
 
     def as_cli_args_dict(self) -> Dict[str, Any]:
         """Flatten vision language config to pure args.
@@ -1387,8 +1357,6 @@ class VisionLanguageConfig:
             else:
                 result[f.name] = value
 
-        result["disable_image_processor"] = self.image_processor is None
-
         return result
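
For orientation, here is a minimal sketch of how the slimmed-down dataclass is now constructed; the same three keyword arguments appear in the `VisionLanguageConfig(...)` call in `aphrodite/engine/args_tools.py` below, and the concrete values are illustrative, taken from the LLaVA example later in this commit:

```python
from aphrodite.common.config import VisionLanguageConfig

# Only three fields remain after this change; the values mirror the
# LLaVA-1.5 defaults used in the updated examples.
vlm_config = VisionLanguageConfig(
    image_token_id=32000,
    image_input_shape=(1, 3, 336, 336),
    image_feature_size=576,
)

# Flatten back into plain CLI-style arguments, e.g. for logging.
print(vlm_config.as_cli_args_dict())
```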
 
 

+ 4 - 4
aphrodite/common/sequence.py

@@ -14,7 +14,7 @@ from aphrodite.lora.request import LoRARequest
 
 if TYPE_CHECKING:
     from aphrodite.inputs import LLMInputs
-    from aphrodite.multimodal import MultiModalData
+    from aphrodite.multimodal import MultiModalDataDict
     from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 
@@ -280,7 +280,7 @@ class Sequence:
         return self.inputs["prompt_token_ids"]
 
     @property
-    def multi_modal_data(self) -> Optional["MultiModalData"]:
+    def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
         return self.inputs.get("multi_modal_data")
 
     @property
@@ -457,7 +457,7 @@ class SequenceGroup:
         return next(iter(self.seqs_dict.values())).prompt_token_ids
 
     @property
-    def multi_modal_data(self) -> Optional["MultiModalData"]:
+    def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
         # All sequences in the group should have the same multi-modal data.
         # We use the multi-modal data of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).multi_modal_data
@@ -639,7 +639,7 @@ class SequenceGroupMetadata:
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
         state: Optional[SequenceGroupState] = None,
-        multi_modal_data: Optional["MultiModalData"] = None,
+        multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
     ) -> None:

+ 4 - 52
aphrodite/engine/args_tools.py

@@ -1,7 +1,6 @@
 import argparse
 import dataclasses
 import json
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
@@ -78,13 +77,9 @@ class EngineArgs:
     model_loader_extra_config: Optional[dict] = None
     preemption_mode: Optional[str] = None
     # Related to Vision-language models such as llava
-    image_input_type: Optional[str] = None
     image_token_id: Optional[int] = None
     image_input_shape: Optional[str] = None
     image_feature_size: Optional[int] = None
-    image_processor: Optional[str] = None
-    image_processor_revision: Optional[str] = None
-    disable_image_processor: bool = False
     # Scheduler config
     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: bool = False
@@ -112,14 +107,6 @@ class EngineArgs:
     @staticmethod
     def add_cli_args_for_vlm(
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
-        parser.add_argument('--image-input-type',
-                            type=str,
-                            default=None,
-                            choices=[
-                                t.name.lower()
-                                for t in VisionLanguageConfig.ImageInputType
-                            ],
-                            help=('The image input type passed.'))
         parser.add_argument('--image-token-id',
                             type=int,
                             default=None,
@@ -135,24 +122,6 @@ class EngineArgs:
             type=int,
             default=None,
             help=('The image feature size along the context dimension.'))
-        parser.add_argument(
-            '--image-processor',
-            type=str,
-            default=EngineArgs.image_processor,
-            help='Name or path of the huggingface image processor to use. '
-            'If unspecified, model name or path will be used.')
-        parser.add_argument(
-            '--image-processor-revision',
-            type=str,
-            default=None,
-            help='Revision of the huggingface image processor version to use. '
-            'It can be a branch name, a tag name, or a commit id. '
-            'If unspecified, will use the default version.')
-        parser.add_argument(
-            '--disable-image-processor',
-            action='store_true',
-            help='Disables the use of image processor, even if one is defined '
-            'for the model on huggingface.')
 
         return parser
 
@@ -742,33 +711,16 @@ class EngineArgs:
             raise ValueError(
                 "BitsAndBytes load format and QLoRA adapter only support "
                 f"'bitsandbytes' quantization, but got {self.quantization}")
-        if self.image_input_type:
-            if (not self.image_token_id or not self.image_input_shape
-                    or not self.image_feature_size):
+        if self.image_token_id is not None:
+            if (not self.image_input_shape or not self.image_feature_size):
                 raise ValueError(
-                    'Specify `image_token_id`, `image_input_shape` and '
-                    '`image_feature_size` together with `image_input_type`.')
-
-            if self.image_processor is None:
-                self.image_processor = self.model
-            if self.disable_image_processor:
-                if self.image_processor != self.model:
-                    warnings.warn(
-                        "You've specified an image processor "
-                        f"({self.image_processor}) but also disabled "
-                        "it via `--disable-image-processor`.",
-                        stacklevel=2)
-
-                self.image_processor = None
+                    'Specify `image_input_shape` and '
+                    '`image_feature_size` together with `image_token_id`.')
 
             vision_language_config = VisionLanguageConfig(
-                image_input_type=VisionLanguageConfig.
-                get_image_input_enum_type(self.image_input_type),
                 image_token_id=self.image_token_id,
                 image_input_shape=str_to_int_tuple(self.image_input_shape),
                 image_feature_size=self.image_feature_size,
-                image_processor=self.image_processor,
-                image_processor_revision=self.image_processor_revision,
             )
         else:
             vision_language_config = None
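
In practice, the presence of `image_token_id` alone now opts a model into the vision-language config, with `image_input_shape` and `image_feature_size` required alongside it. A hedged sketch of the corresponding engine arguments (the field names are the ones kept in this diff; `model` is the usual EngineArgs field, not shown in this hunk, and the values are illustrative):

```python
from aphrodite.engine.args_tools import EngineArgs

# Sketch only: --image-input-type, --image-processor,
# --image-processor-revision and --disable-image-processor are gone.
# When the engine config is later built from these args, supplying
# image_token_id without the other two fields hits the new ValueError above.
engine_args = EngineArgs(
    model="llava-hf/llava-1.5-7b-hf",
    image_token_id=32000,
    image_input_shape="1,3,336,336",
    image_feature_size=576,
)
```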

+ 5 - 5
aphrodite/inputs/data.py

@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
 from typing_extensions import NotRequired
 
 if TYPE_CHECKING:
-    from aphrodite.multimodal import MultiModalData
+    from aphrodite.multimodal import MultiModalDataDict
 
 
 class ParsedText(TypedDict):
@@ -71,7 +71,7 @@ class TextPrompt(TypedDict):
     prompt: str
     """The input text to be tokenized before passing to the model."""
 
-    multi_modal_data: NotRequired["MultiModalData"]
+    multi_modal_data: NotRequired["MultiModalDataDict"]
     """
     Optional multi-modal data to pass to the model,
     if the model supports it.
@@ -84,7 +84,7 @@ class TokensPrompt(TypedDict):
     prompt_token_ids: List[int]
     """A list of token IDs to pass to the model."""
 
-    multi_modal_data: NotRequired["MultiModalData"]
+    multi_modal_data: NotRequired["MultiModalDataDict"]
     """
     Optional multi-modal data to pass to the model,
     if the model supports it.
@@ -102,7 +102,7 @@ class TextTokensPrompt(TypedDict):
     prompt_token_ids: List[int]
     """The token IDs of the prompt."""
 
-    multi_modal_data: NotRequired["MultiModalData"]
+    multi_modal_data: NotRequired["MultiModalDataDict"]
     """
     Optional multi-modal data to pass to the model,
     if the model supports it.
@@ -135,7 +135,7 @@ class LLMInputs(TypedDict):
     The original prompt text corresponding to the token IDs, if available.
     """
 
-    multi_modal_data: NotRequired[Optional["MultiModalData"]]
+    multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
     """
     Optional multi-modal data to pass to the model,
     if the model supports it.
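
Concretely, callers now attach raw multi-modal objects under a modality key instead of wrapping them. A minimal sketch using the `TextPrompt` TypedDict from this file (the image path follows the LLaVA example later in this commit):

```python
from PIL import Image

from aphrodite.inputs.data import TextPrompt

# The image is passed directly under the "image" key of MultiModalDataDict,
# replacing the removed ImagePixelData wrapper.
prompt = TextPrompt(
    prompt="<image>" * 576 + "\nUSER: What is in this image?\nASSISTANT:",
    multi_modal_data={"image": Image.open("images/stop_sign.jpg")},
)
```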

+ 4 - 3
aphrodite/inputs/registry.py

@@ -12,7 +12,7 @@ from .data import LLMInputs
 if TYPE_CHECKING:
     from aphrodite.common.config import ModelConfig, VisionLanguageConfig
     from aphrodite.common.sequence import SequenceData
-    from aphrodite.multimodal import MultiModalData
+    from aphrodite.multimodal import MultiModalDataDict
 
 C = TypeVar("C", bound=PretrainedConfig)
 
@@ -61,7 +61,8 @@ class InputContext:
 N = TypeVar("N", bound=Type[nn.Module])
 
 DummyDataFactory = Callable[[InputContext, int],
-                            Tuple["SequenceData", Optional["MultiModalData"]]]
+                            Tuple["SequenceData",
+                                  Optional["MultiModalDataDict"]]]
 """
 Create dummy data to be inputted into the model.
 Note:
@@ -88,7 +89,7 @@ class InputRegistry:
         self,
         ctx: InputContext,
         seq_len: int,
-    ) -> Tuple["SequenceData", Optional["MultiModalData"]]:
+    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
         """
         The default dummy data factory represents the longest possible text
         that can be inputted to the model.

+ 2 - 3
aphrodite/modeling/model_loader/loader.py

@@ -80,9 +80,8 @@ def _get_model_initialization_kwargs(
 
     if supports_vision(model_class):
         if vlm_config is None:
-            raise ValueError("Provide `image_input_type` and other vision "
-                             "related configurations through LLM entrypoint "
-                             "or engine arguments.")
+            raise ValueError("Provide vision related configurations "
+                             "through LLM entrypoint or engine arguments.")
 
         extra_kwargs["vlm_config"] = vlm_config
 

+ 2 - 18
aphrodite/modeling/models/clip.py

@@ -12,7 +12,6 @@ from aphrodite.common.sequence import SequenceData
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
-from aphrodite.multimodal.image import ImageFeatureData, ImagePixelData
 from aphrodite.quantization import QuantizationConfig
 
 
@@ -49,7 +48,7 @@ def dummy_seq_data_for_clip(
     return SequenceData(token_ids)
 
 
-def dummy_pixel_data_for_clip(
+def dummy_image_for_clip(
     hf_config: CLIPVisionConfig,
     *,
     image_width_override: Optional[int] = None,
@@ -62,22 +61,7 @@ def dummy_pixel_data_for_clip(
         height = image_height_override
 
     image = Image.new("RGB", (width, height), color=0)
-    return ImagePixelData(image)
-
-
-def dummy_feature_data_for_clip(
-    hf_config: CLIPVisionConfig,
-    *,
-    image_feature_size_override: Optional[int] = None,
-):
-    if image_feature_size_override is None:
-        image_feature_size = get_clip_image_feature_size(hf_config)
-    else:
-        image_feature_size = image_feature_size_override
-
-    values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
-                         dtype=torch.float16)
-    return ImageFeatureData(values)
+    return {"image": image}
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa

+ 20 - 83
aphrodite/modeling/models/llava.py

@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import torch
 import torch.nn as nn
@@ -16,11 +16,10 @@ from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.clip import CLIPVisionModel
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalData
+from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.common.sequence import SamplerOutput
 
-from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
-                   dummy_seq_data_for_clip)
+from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsVision
 
 _KEYS_TO_MODIFY_MAPPING = {
@@ -29,7 +28,7 @@ _KEYS_TO_MODIFY_MAPPING = {
 }
 
 
-# TODO(xwjiang): Run benchmark and decide if TP.
+# TODO: Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):
 
     def __init__(self, vision_hidden_size: int, text_hidden_size: int,
@@ -75,17 +74,10 @@ class LlavaImagePixelInputs(TypedDict):
     """Shape: (batch_size, num_channels, height, width)"""
 
 
-class LlavaImageFeatureInputs(TypedDict):
-    type: Literal["image_features"]
-    data: torch.Tensor
-    """Shape: (batch_size, image_feature_size, hidden_size)"""
-
-
-LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
+LlavaImageInputs = LlavaImagePixelInputs
 
 
 def dummy_data_for_llava(ctx: InputContext, seq_len: int):
-    multimodal_config = ctx.get_multimodal_config()
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
 
@@ -96,13 +88,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
             image_token_id=hf_config.image_token_index,
         )
 
-        image_input_type = multimodal_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-        mm_data: MultiModalData
-        if image_input_type == ImageInputType.PIXEL_VALUES:
-            mm_data = dummy_pixel_data_for_clip(vision_config)
-        elif image_input_type == ImageInputType.IMAGE_FEATURES:
-            mm_data = dummy_feature_data_for_clip(vision_config)
+        mm_data = dummy_image_for_clip(vision_config)
 
         return seq_data, mm_data
 
@@ -110,8 +96,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
     raise NotImplementedError(msg)
 
 
-@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
 class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
@@ -125,11 +110,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.vlm_config = vlm_config
 
-        if self.vlm_config.image_input_type == (
-                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
-            self.vision_tower = CLIPVisionModel(config.vision_config)
-        else:
-            self.vision_tower = None
+        # TODO: Optionally initialize this to support embedding inputs.
+        self.vision_tower = CLIPVisionModel(config.vision_config)
 
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
@@ -164,44 +146,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        image_features = kwargs.pop("image_features", None)
-
-        expected_input_type = self.vlm_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-
-        if expected_input_type == ImageInputType.PIXEL_VALUES:
-            if image_features is not None:
-                raise ValueError(
-                    "Expected pixel values but got image features")
-            if pixel_values is None:
-                return None
-
-            if not isinstance(pixel_values, torch.Tensor):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
 
-            return LlavaImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_image_data(pixel_values),
-            )
+        if pixel_values is None:
+            return None
 
-        if expected_input_type == ImageInputType.IMAGE_FEATURES:
-            if pixel_values is not None:
-                raise ValueError(
-                    "Expected image features but got pixel values")
-            if image_features is None:
-                return None
+        if not isinstance(pixel_values, torch.Tensor):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
 
-            if not isinstance(image_features, torch.Tensor):
-                raise ValueError("Incorrect type of image features. "
-                                 f"Got type: {type(image_features)}")
-
-            return LlavaImageFeatureInputs(
-                type="image_features",
-                data=self._validate_image_data(image_features),
-            )
-
-        return None
+        return LlavaImagePixelInputs(
+            type="pixel_values",
+            data=self._validate_image_data(pixel_values),
+        )
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -236,12 +192,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
     def _process_image_input(self,
                              image_input: LlavaImageInputs) -> torch.Tensor:
-        if image_input["type"] == "pixel_values":
-            assert self.vision_tower is not None
-            image_features = self._process_image_pixels(image_input)
-        else:
-            image_features = image_input["data"]
-
+        assert self.vision_tower is not None
+        image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)
 
     def forward(
@@ -272,25 +224,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         This way, the `positions` and `attn_metadata` are consistent
         with the `input_ids`.
 
-        This model has two modes of image inputs:
-        `PIXEL_VALUES` and `IMAGE_FEATURES`.
-
         Args:
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each input image.
-                Expects a batch with shape `[1, 3, 336, 336]`.
-                (Only applicable to `PIXEL_VALUES` mode)
-            image_features: The image features for each input image outputted by
-                the vision tower before passing to the multi-modal projector.
-                Expects a batch with shape `[1, 576, 1024]`.
-                (Only applicable to `IMAGE_FEATURES` mode)
-
-        See also:
-            Each input maps to huggingface implementation, as follows:
-
-            - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
-            - `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
 

+ 47 - 95
aphrodite/modeling/models/llava_next.py

@@ -1,9 +1,9 @@
-from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
-                    Union)
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import torch
 import torch.nn as nn
 from loguru import logger
+from PIL import Image
 from transformers import CLIPVisionConfig, LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
@@ -17,20 +17,17 @@ from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.clip import (CLIPVisionModel,
-                                            dummy_feature_data_for_clip,
-                                            dummy_pixel_data_for_clip,
-                                            dummy_seq_data_for_clip,
-                                            get_clip_patch_grid_length)
-from aphrodite.modeling.models.interfaces import SupportsVision
+from aphrodite.modeling.models.clip import CLIPVisionModel
 from aphrodite.modeling.models.llama import LlamaModel
-from aphrodite.modeling.models.llava import (LlavaMultiModalProjector,
-                                             merge_vision_embeddings)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalData
-from aphrodite.multimodal.image import ImagePixelData
+from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.quantization.base_config import QuantizationConfig
 
+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   get_clip_patch_grid_length)
+from .interfaces import SupportsVision
+from .llava import LlavaMultiModalProjector, merge_vision_embeddings
+
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
     "language_model.model": "language_model",
@@ -46,17 +43,7 @@ class LlavaNextImagePixelInputs(TypedDict):
     """Shape: (batch_size, 2)"""
 
 
-class LlavaNextImageFeatureInputs(TypedDict):
-    type: Literal["image_features"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""
-
-    image_sizes: NotRequired[torch.Tensor]
-    """Shape: (batch_size, 2)"""
-
-
-LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
-                             LlavaNextImageFeatureInputs]
+LlavaNextImageInputs = LlavaNextImagePixelInputs
 
 
 def _get_llava_next_num_unpadded_features(
@@ -137,20 +124,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
             image_feature_size_override=image_feature_size,
         )
 
-        image_input_type = multimodal_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-        mm_data: MultiModalData
-        if image_input_type == ImageInputType.PIXEL_VALUES:
-            mm_data = dummy_pixel_data_for_clip(
-                vision_config,
-                image_width_override=dummy_width,
-                image_height_override=dummy_height,
-            )
-        elif image_input_type == ImageInputType.IMAGE_FEATURES:
-            mm_data = dummy_feature_data_for_clip(
-                vision_config,
-                image_feature_size_override=image_feature_size,
-            )
+        mm_data = dummy_image_for_clip(
+            vision_config,
+            image_width_override=dummy_width,
+            image_height_override=dummy_height,
+        )
 
         return seq_data, mm_data
 
@@ -158,31 +136,26 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     raise NotImplementedError(msg)
 
 
-def _pixel_mapper(ctx: InputContext,
-                  data: ImagePixelData) -> Dict[str, torch.Tensor]:
-    image = data.image
+def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
 
-    if isinstance(image, torch.Tensor):
-        pixel_values = image.to(ctx.model_config.dtype)
-        batch_size, _, _, h, w = pixel_values.shape
-        image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])
+    if isinstance(image, Image.Image):
 
-        return {"pixel_values": pixel_values, "image_sizes": image_sizes}
+        # Temporary patch before dynamic number of image tokens is supported
+        _, _, h, w = ctx.get_multimodal_config().image_input_shape
+        if (w, h) != (image.width, image.height):
+            logger.warning(
+                "Dynamic image shape is currently not supported. "
+                "Resizing input image to (%d, %d).", w, h)
 
-    # Temporary patch before dynamic number of image tokens is supported
-    _, _, h, w = ctx.get_multimodal_config().image_input_shape
-    if (w, h) != (image.width, image.height):
-        logger.warning("Dynamic image shape is currently not supported. "
-                       f"Resizing input image to ({w}, {h}).")
+            image = image.resize((w, h))
 
-        data.image = image.resize((w, h))
+        return MULTIMODAL_REGISTRY._get_plugin("image") \
+            ._default_input_mapper(ctx, image)
 
-    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
-        ._default_input_mapper(ctx, data)
+    raise TypeError(f"Invalid type for 'image': {type(image)}")
 
 
-@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
+@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
@@ -196,11 +169,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.vlm_config = vlm_config
 
-        if self.vlm_config.image_input_type == (
-                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
-            self.vision_tower = CLIPVisionModel(config=config.vision_config)
-        else:
-            raise TypeError("Image features are not supported by LLaVA-NeXT")
+        self.vision_tower = CLIPVisionModel(config=config.vision_config)
 
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
@@ -226,9 +195,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
     def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
         _, num_channels, _, _ = self.vlm_config.image_input_shape
 
-        # Note that this is different from that of Aphrodite
-        # vision_language_config since the image is resized by the HuggingFace
-        # preprocessor
+        # Note that this is different from that of vLLM vision_language_config
+        # since the image is resized by the HuggingFace preprocessor
         height = width = self.config.vision_config.image_size
 
         if list(data.shape[2:]) != [num_channels, height, width]:
@@ -236,7 +204,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
                 f"The expected image tensor shape is batch dimension plus "
                 f"num_patches plus {[num_channels, height, width]}. "
                 f"You supplied {data.shape}. "
-                f"If you are using Aphrodite's entrypoint, make sure your "
+                f"If you are using vLLM's entrypoint, make sure your "
                 f"supplied image input is consistent with "
                 f"image_input_shape in engine args.")
 
@@ -254,36 +222,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
-        image_features = kwargs.pop("image_features", None)
-
-        expected_input_type = self.vlm_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-
-        if expected_input_type == ImageInputType.PIXEL_VALUES:
-            if image_features is not None:
-                raise ValueError(
-                    "Expected pixel values but got image features")
-            if pixel_values is None:
-                return None
 
-            if not isinstance(pixel_values, torch.Tensor):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
+        if pixel_values is None or image_sizes is None:
+            return None
 
-            if not isinstance(image_sizes, torch.Tensor):
-                raise ValueError("Incorrect type of image sizes. "
-                                 f"Got type: {type(image_sizes)}")
+        if not isinstance(pixel_values, torch.Tensor):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
 
-            return LlavaNextImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_image_pixels(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
-            )
+        if not isinstance(image_sizes, torch.Tensor):
+            raise ValueError("Incorrect type of image sizes. "
+                             f"Got type: {type(image_sizes)}")
 
-        assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
-            "Failed to validate this at initialization time")
-
-        return None
+        return LlavaNextImagePixelInputs(
+            type="pixel_values",
+            data=self._validate_image_pixels(pixel_values),
+            image_sizes=self._validate_image_sizes(image_sizes),
+        )
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -390,11 +345,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
     def _process_image_input(
             self, image_input: LlavaNextImageInputs) -> torch.Tensor:
-        if image_input["type"] == "pixel_values":
-            assert self.vision_tower is not None
-            image_features = self._process_image_pixels(image_input)
-        else:
-            image_features = image_input["data"]
+        assert self.vision_tower is not None
+        image_features = self._process_image_pixels(image_input)
 
         patch_embeddings = self.multi_modal_projector(image_features)
 

+ 9 - 17
aphrodite/modeling/models/phi3v.py

@@ -1,4 +1,5 @@
 # coding=utf-8
+# Copyright 2024 The PygmalionAI team.
 # Copyright 2024 The vLLM team.
 # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
 #
@@ -34,10 +35,9 @@ from aphrodite.modeling.models.clip import CLIPVisionModel
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.multimodal.image import ImagePixelData
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip
+from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsVision
 
 _KEYS_TO_MODIFY_MAPPING = {
@@ -283,7 +283,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
         image_token_id=32044,
         image_feature_size_override=image_feature_size,
     )
-    mm_data = dummy_pixel_data_for_clip(
+    mm_data = dummy_image_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
         image_width_override=dummy_width,
         image_height_override=dummy_height,
@@ -328,8 +328,7 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
 
 
 def _image_processor(ctx: InputContext,
-                     data: ImagePixelData) -> Dict[str, torch.Tensor]:
-    image = data.image
+                     image: object) -> Dict[str, torch.Tensor]:
 
     if isinstance(image, Image.Image):
         # Temporary patch before dynamic number of image tokens is supported
@@ -339,13 +338,14 @@ def _image_processor(ctx: InputContext,
             logger.warning("Dynamic image shape is currently not supported. "
                            f"Resizing input image to ({w}, {h}).")
 
-            data.image = image.resize((w, h))
+            image = image.resize((w, h))
 
-    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
-            ._default_input_mapper(ctx, data)
+        return MULTIMODAL_REGISTRY._get_plugin("image") \
+                ._default_input_mapper(ctx, image)
+    raise TypeError(f"Invalid type for 'image': {type(image)}")
 
 
-@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_image_processor)
+@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
 class Phi3VForCausalLM(nn.Module, SupportsVision):
 
@@ -371,14 +371,6 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
 
-        expected_input_type = self.vlm_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-
-        if expected_input_type != ImageInputType.PIXEL_VALUES:
-            raise ValueError(
-                f"Unexpected image input type: {expected_input_type}."
-                "Phi3v only support pixel_values input currently.")
-
         if pixel_values is not None and image_sizes is not None:
             return Phi3VImagePixelInputs(type="pixel_values",
                                          data=pixel_values,

+ 6 - 4
aphrodite/multimodal/__init__.py

@@ -1,4 +1,4 @@
-from .base import MultiModalData, MultiModalPlugin
+from .base import MultiModalDataDict, MultiModalPlugin
 from .registry import MultiModalRegistry
 
 MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -10,6 +10,8 @@ See also:
 """
 
 __all__ = [
-    "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
-    "MultiModalRegistry"
-]
+    "MultiModalPlugin",
+    "MULTIMODAL_REGISTRY",
+    "MultiModalRegistry",
+    "MultiModalDataDict",
+]

+ 28 - 30
aphrodite/multimodal/base.py

@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
-from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type,
+                    TypedDict, TypeVar, Union)
 
 from loguru import logger
 
@@ -9,36 +9,33 @@ from aphrodite.inputs import InputContext
 
 if TYPE_CHECKING:
     import torch
+    from PIL import Image
     from torch import nn
 
+N = TypeVar("N", bound=Type["nn.Module"])
 
-class MultiModalData:
-    """
-    Base class that contains multi-modal data.
-
-    To add a new modality, add a new file under ``multimodal`` directory.
 
-    In this new file, subclass :class:`~MultiModalData` and
-    :class:`~MultiModalPlugin`.
+class MultiModalDataBuiltins(TypedDict, total=False):
+    image: "Image.Image"
 
-    Finally, register the new plugin to
-    :const:`aphrodite.multimodal.MULTIMODAL_REGISTRY`.
-    This enables models to call :meth:`MultiModalRegistry.map_input` for
-    the new modality.
-    """
-    pass
 
+MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
+"""
+A dictionary containing an item for each modality type to input.
 
-D = TypeVar("D", bound=MultiModalData)
-N = TypeVar("N", bound=Type["nn.Module"])
+The data belonging to each modality is converted into keyword arguments 
+to the model by the corresponding mapper. By default, the mapper of 
+the corresponding plugin with the same modality key is applied.
+"""
 
-MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]]
+MultiModalInputMapper = Callable[[InputContext, object], Dict[str,
+                                                              "torch.Tensor"]]
 """Return a dictionary to be passed as keyword arguments to
 :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
 and processors in HuggingFace Transformers."""
 
 
-class MultiModalPlugin(ABC, Generic[D]):
+class MultiModalPlugin(ABC):
     """
     Base class that defines data processing logic for a specific modality.
 
@@ -51,19 +48,18 @@ class MultiModalPlugin(ABC, Generic[D]):
 
     def __init__(self) -> None:
         self._input_mappers: Dict[Type["nn.Module"],
-                                  MultiModalInputMapper[D]] = {}
+                                  MultiModalInputMapper] = {}
 
     @abstractmethod
-    def get_data_type(self) -> Type[D]:
+    def get_data_key(self) -> str:
         """
-        Get the modality (subclass of :class:`~MultiModalData`) served by
-        this plugin.
+        Get the data key corresponding to the modality.
         """
         raise NotImplementedError
 
     @abstractmethod
     def _default_input_mapper(self, ctx: InputContext,
-                              data: D) -> Dict[str, "torch.Tensor"]:
+                              data: object) -> Dict[str, "torch.Tensor"]:
         """Return a dictionary to be passed as keyword arguments to
         :meth:`~torch.nn.Module.forward`. This is similar in concept to
         tokenizers and processors in HuggingFace Transformers.
@@ -72,11 +68,10 @@ class MultiModalPlugin(ABC, Generic[D]):
 
     def register_input_mapper(
         self,
-        mapper: Optional[MultiModalInputMapper[D]] = None,
+        mapper: Optional[MultiModalInputMapper] = None,
     ):
         """
         Register an input mapper to a model class.
-        
         When the model receives input data that matches the modality served by
         this plugin (see :meth:`get_data_type`), the provided function is
         invoked to transform the data into a dictionary of model inputs.
@@ -89,8 +84,9 @@ class MultiModalPlugin(ABC, Generic[D]):
         def wrapper(model_cls: N) -> N:
             if model_cls in self._input_mappers:
                 logger.warning(
-                    f"Model class {model_cls} already has an input mapper "
-                    f"registered to {self}. It is overwritten by the new one.")
+                    "Model class %s already has an input mapper "
+                    "registered to %s. It is overwritten by the new one.",
+                    model_cls, self)
 
             self._input_mappers[model_cls] = mapper \
                 or self._default_input_mapper
@@ -100,11 +96,13 @@ class MultiModalPlugin(ABC, Generic[D]):
         return wrapper
 
     def map_input(self, model_config: ModelConfig,
-                  data: D) -> Dict[str, "torch.Tensor"]:
+                  data: object) -> Dict[str, "torch.Tensor"]:
         """
-        Apply an input mapper to a :class:`~MultiModalData` instance passed
+        Apply an input mapper to a data passed
         to the model, transforming the data into a dictionary of model inputs.
 
+        If the data is not something that the mapper expects, a TypeError is raised.
+
         The model is identified by ``model_config``.
 
         TODO: Add guide [ref: PR #5276]
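
Since plugins are now keyed by a string modality instead of a `MultiModalData` subclass, adding a modality boils down to implementing the two abstract methods above. A rough sketch of a hypothetical plugin (the `audio` key and tensor handling are made up purely for illustration):

```python
from typing import Dict

import torch

from aphrodite.inputs import InputContext
from aphrodite.multimodal.base import MultiModalPlugin


class AudioPlugin(MultiModalPlugin):
    """Hypothetical plugin serving an "audio" modality key."""

    def get_data_key(self) -> str:
        # Matches the key callers would use in MultiModalDataDict,
        # e.g. {"audio": waveform_tensor}.
        return "audio"

    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> Dict[str, torch.Tensor]:
        if isinstance(data, torch.Tensor):
            # Cast to the model dtype and expose it as a forward() kwarg.
            return {"audio_values": data.to(ctx.model_config.dtype)}
        raise TypeError(f"Invalid type for 'audio': {type(data)}")
```

Such a plugin would then be registered via `MULTIMODAL_REGISTRY.register_plugin(...)`, which is unchanged apart from the key-based lookup.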

+ 12 - 81
aphrodite/multimodal/image.py

@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import Dict, Type, Union
+from typing import Dict
 
 import torch
 from loguru import logger
@@ -9,103 +9,34 @@ from aphrodite.common.config import ModelConfig
 from aphrodite.inputs.registry import InputContext
 from aphrodite.transformers_utils.image_processor import get_image_processor
 
-from .base import MultiModalData, MultiModalPlugin
+from .base import MultiModalPlugin
 
 cached_get_image_processor = lru_cache(get_image_processor)
 
 
-class ImagePixelData(MultiModalData):
-    """
-    The pixel data of an image. Can be one of:
+class ImagePlugin(MultiModalPlugin):
 
-    - :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
-      processor is available to the model.
-    - :class:`torch.Tensor`: The raw pixel data which is passed to the model
-      without additional pre-processing.
-    """
-
-    def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
-        if isinstance(image, Image.Image):
-            # So that this class can be created inside the Image context manager
-            image.load()
-
-        self.image = image
-
-    def __repr__(self) -> str:
-        image = self.image
-        if isinstance(image, Image.Image):
-            return f"{type(self).__name__}(image={image})"
-
-        return (f"{type(self).__name__}(image=torch.Tensor(shape="
-                f"{image.shape}, dtype={image.dtype}))")
-
-
-class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
-
-    def get_data_type(self) -> Type[ImagePixelData]:
-        return ImagePixelData
+    def get_data_key(self) -> str:
+        return "image"
 
     def _get_hf_image_processor(self, model_config: ModelConfig):
-        vlm_config = model_config.multimodal_config
-        if vlm_config is None or vlm_config.image_processor is None:
-            return None
-
         return cached_get_image_processor(
-            vlm_config.image_processor,
-            trust_remote_code=model_config.trust_remote_code,
-            revision=vlm_config.image_processor_revision,
-        )
+            model_config.model,
+            trust_remote_code=model_config.trust_remote_code)
 
     def _default_input_mapper(self, ctx: InputContext,
-                              data: ImagePixelData) -> Dict[str, torch.Tensor]:
+                              data: object) -> Dict[str, torch.Tensor]:
         model_config = ctx.model_config
-        image = data.image
-
-        if isinstance(image, Image.Image):
+        if isinstance(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
             if image_processor is None:
                 raise RuntimeError("No HuggingFace processor is available"
                                    "to process the image object")
             try:
-                return image_processor.preprocess(image, return_tensors="pt") \
+                return image_processor.preprocess(data, return_tensors="pt") \
                     .to(model_config.dtype).data
             except Exception:
-                logger.error(f"Failed to process image ({image})")
+                logger.error("Failed to process image (%s)", data)
                 raise
-        elif isinstance(image, torch.Tensor):
-            pixel_values = image.to(model_config.dtype)
-
-            return {"pixel_values": pixel_values}
-
-        raise TypeError(f"Invalid image type: {type(image)}")
-
-
-class ImageFeatureData(MultiModalData):
-    """
-    The feature vector of an image, passed directly to the model.
-
-    This should be the output of the vision tower.
-    """
-
-    def __init__(self, image_features: torch.Tensor) -> None:
-        self.image_features = image_features
-
-    def __repr__(self) -> str:
-        image_features = self.image_features
-
-        return (f"{type(self).__name__}(image_features=torch.Tensor(shape="
-                f"{image_features.shape}, dtype={image_features.dtype}))")
-
-
-class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
-
-    def get_data_type(self) -> Type[ImageFeatureData]:
-        return ImageFeatureData
-
-    def _default_input_mapper(
-            self, ctx: InputContext,
-            data: ImageFeatureData) -> Dict[str, torch.Tensor]:
-        model_config = ctx.model_config
-        image_features = data.image_features.to(model_config.dtype)
 
-        return {"image_features": image_features}
+        raise TypeError(f"Invalid type for 'image': {type(data)}")

+ 54 - 45
aphrodite/multimodal/registry.py

@@ -1,16 +1,14 @@
 import functools
-from typing import Any, Optional, Sequence, Type, TypeVar
+from typing import Optional, Sequence, Type, TypeVar
 
 from loguru import logger
 from torch import nn
 
 from aphrodite.common.config import ModelConfig
 
-from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
-from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
-                    ImagePixelPlugin)
+from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin
+from .image import ImagePlugin
 
-D = TypeVar("D", bound=MultiModalData)
 N = TypeVar("N", bound=Type[nn.Module])
 
 
@@ -18,80 +16,91 @@ class MultiModalRegistry:
     """
     A registry to dispatch data processing
     according to its modality and the target model.
+
+    The registry handles both external and internal data input.
     """
 
-    DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
+    DEFAULT_PLUGINS = (ImagePlugin(), )
 
     def __init__(
-        self,
-        *,
-        plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS,
-    ) -> None:
-        self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
+            self,
+            *,
+            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
+        self._plugins = {p.get_data_key(): p for p in plugins}
 
-    def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
-        data_type = plugin.get_data_type()
+    def register_plugin(self, plugin: MultiModalPlugin) -> None:
+        data_type_key = plugin.get_data_key()
 
-        if data_type in self._plugins_by_data_type:
+        if data_type_key in self._plugins:
             logger.warning(
-                f"A plugin is already registered for data type {data_type}, "
-                f"and will be overwritten by the new plugin {plugin}.")
+                "A plugin is already registered for data type %s, "
+                "and will be overwritten by the new plugin %s.", data_type_key,
+                plugin)
 
-        self._plugins_by_data_type[data_type] = plugin
+        self._plugins[data_type_key] = plugin
 
-    def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
-        for typ in data_type.mro():
-            plugin = self._plugins_by_data_type.get(typ)
-            if plugin is not None:
-                return plugin
+    def _get_plugin(self, data_type_key: str):
+        plugin = self._plugins.get(data_type_key)
+        if plugin is not None:
+            return plugin
 
-        msg = f"Unknown multi-modal data type: {data_type}"
+        msg = f"Unknown multi-modal data type: {data_type_key}"
         raise NotImplementedError(msg)
 
-    def register_input_mapper(
+    def register_image_input_mapper(
         self,
-        data_type: Type[D],
-        mapper: Optional[MultiModalInputMapper[D]] = None,
+        mapper: Optional[MultiModalInputMapper] = None,
     ):
         """
-        Register an input mapper for a specific modality to a model class.
+        Register an input mapper for image data to a model class.
 
         See :meth:`MultiModalPlugin.register_input_mapper` for more details.
         """
-        return self._get_plugin_for_data_type(data_type) \
-            .register_input_mapper(mapper)
+        return self.register_input_mapper("image", mapper)
+
+    def _process_input(self, key: str, value: object,
+                       model_config: ModelConfig):
+        plugin = self._plugins.get(key)
+        if plugin:
+            return plugin.map_input(model_config, value)
+        msg = f"Unknown multi-modal data type: {key}"
+        raise NotImplementedError(msg)
 
-    def register_image_pixel_input_mapper(
+    def register_input_mapper(
         self,
-        mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
+        data_type: str,
+        mapper: Optional[MultiModalInputMapper] = None,
     ):
         """
-        Register an input mapper for image pixel data to a model class.
+        Register an input mapper for a specific modality to a model class.
 
         See :meth:`MultiModalPlugin.register_input_mapper` for more details.
         """
-        return self.register_input_mapper(ImagePixelData, mapper)
-
-    def register_image_feature_input_mapper(
-        self,
-        mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
-    ):
+        plugin = self._plugins.get(data_type)
+        if not plugin:
+            msg = f"Unknown multi-modal data type: {data_type}"
+            raise NotImplementedError(msg)
+        return plugin.register_input_mapper(mapper)
+
+    def register_image_input(self,
+                             mapper: Optional[MultiModalInputMapper] = None):
         """
-        Register an input mapper for image feature data to a model class.
+        Register an input mapper for image pixel data to a model class.
 
         See :meth:`MultiModalPlugin.register_input_mapper` for more details.
         """
-        return self.register_input_mapper(ImageFeatureData, mapper)
+        return self.register_input_mapper("image", mapper)
 
-    def map_input(self, model_config: ModelConfig, data: MultiModalData):
+    def map_input(self, model_config: ModelConfig, data: MultiModalDataDict):
         """
-        Apply an input mapper to a :class:`~MultiModalData` instance passed
-        to the model.
+        Apply an input mapper to the data passed to the model.
         
         See :meth:`MultiModalPlugin.map_input` for more details.
         """
-        return self._get_plugin_for_data_type(type(data)) \
-            .map_input(model_config, data)
+        result_list = [
+            self._process_input(k, v, model_config) for k, v in data.items()
+        ]
+        return {k: v for d in result_list for k, v in d.items()}
 
     def create_input_mapper(self, model_config: ModelConfig):
         """

+ 97 - 0
aphrodite/multimodal/utils.py

@@ -0,0 +1,97 @@
+import base64
+import os
+from io import BytesIO
+from typing import Optional, Union
+from urllib.parse import urlparse
+
+import aiohttp
+from PIL import Image
+
+from aphrodite.common.config import ModelConfig
+from aphrodite.multimodal.base import MultiModalDataDict
+
+
+APHRODITE_IMAGE_FETCH_TIMEOUT = int(os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT",
+                                              10))
+
+class ImageFetchAiohttp:
+    aiohttp_client: Optional[aiohttp.ClientSession] = None
+
+    @classmethod
+    def get_aiohttp_client(cls) -> aiohttp.ClientSession:
+        if cls.aiohttp_client is None:
+            timeout = aiohttp.ClientTimeout(total=APHRODITE_IMAGE_FETCH_TIMEOUT)
+            connector = aiohttp.TCPConnector()
+            cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
+                                                       connector=connector)
+
+        return cls.aiohttp_client
+
+    @classmethod
+    async def fetch_image(cls, image_url: str) -> Image.Image:
+        """Load PIL image from a url or base64 encoded openai GPT4V format"""
+
+        if image_url.startswith('http'):
+            parsed_url = urlparse(image_url)
+            if parsed_url.scheme not in ["http", "https"]:
+                raise ValueError("Invalid 'image_url': A valid 'image_url' "
+                                 "must have scheme 'http' or 'https'.")
+            # Avoid circular import
+            from aphrodite import __version__ as APHRODITE_VERSION
+
+            client = cls.get_aiohttp_client()
+            headers = {"User-Agent": f"aphrodite/{APHRODITE_VERSION}"}
+
+            async with client.get(url=image_url, headers=headers) as response:
+                response.raise_for_status()
+                image_raw = await response.read()
+            image = Image.open(BytesIO(image_raw))
+
+        # Only split once and assume the second part is the base64 encoded image
+        elif image_url.startswith('data:image'):
+            image = load_image_from_base64(image_url.split(',', 1)[1])
+
+        else:
+            raise ValueError(
+                "Invalid 'image_url': A valid 'image_url' must start "
+                "with either 'data:image' or 'http'.")
+
+        image.load()
+        return image
+
+
+def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
+    """Encode a pillow image to base64 format."""
+
+    buffered = BytesIO()
+    if format == 'JPEG':
+        image = image.convert('RGB')
+    image.save(buffered, format)
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
+def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
+    """Load image from base64 format."""
+    return Image.open(BytesIO(base64.b64decode(image)))
+
+
+# TODO: move this to a model registry for preprocessing vision
+# language prompts based on the model type.
+def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
+                               config: ModelConfig) -> str:
+    """Combine image and text prompts for vision language model depending on
+    the model architecture."""
+
+    if config.hf_config.model_type in ("llava", "llava_next"):
+        full_prompt = f"{image_prompt}\n{text_prompt}"
+    elif config.hf_config.model_type == 'phi3_v':
+        full_prompt = f"{image_prompt}<s>\n{text_prompt}"
+    else:
+        raise ValueError(
+            f"Unsupported model type: {config.hf_config.model_type}")
+    return full_prompt
+
+
+async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
+    image = await ImageFetchAiohttp.fetch_image(image_url)
+    return {"image": image}

+ 1 - 1
aphrodite/task_handler/model_runner.py

@@ -523,7 +523,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                      is not None else 1))
 
                 mm_data = seq_group_metadata.multi_modal_data
-                if mm_data is not None:
+                if mm_data:
                     # Process multi-modal data
                     mm_kwargs = self.multi_modal_input_mapper(mm_data)
                     for k, v in mm_kwargs.items():

+ 0 - 4
aphrodite/transformers_utils/image_processor.py

@@ -1,5 +1,3 @@
-from typing import Optional
-
 from transformers import AutoImageProcessor
 from transformers.image_processing_utils import BaseImageProcessor
 
@@ -8,7 +6,6 @@ def get_image_processor(
     processor_name: str,
     *args,
     trust_remote_code: bool = False,
-    revision: Optional[str] = None,
     **kwargs,
 ) -> BaseImageProcessor:
     """Gets an image processor for the given model name via HuggingFace."""
@@ -17,7 +14,6 @@ def get_image_processor(
             processor_name,
             *args,
             trust_remote_code=trust_remote_code,
-            revision=revision,
             **kwargs)
     except ValueError as e:
         # If the error pertains to the processor class not existing or not

+ 8 - 48
examples/vision/llava_example.py

@@ -1,38 +1,32 @@
-import argparse
 import os
 import subprocess
 
-import torch
 from PIL import Image
 
 from aphrodite import LLM
-from aphrodite.multimodal.image import ImageFeatureData, ImagePixelData
 
 # The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
 # You can use `.buildkite/download-images.sh` to download them
 
 
-def run_llava_pixel_values(*, disable_image_processor: bool = False):
+def run_llava():
     llm = LLM(
         model="llava-hf/llava-1.5-7b-hf",
-        image_input_type="pixel_values",
         image_token_id=32000,
         image_input_shape="1,3,336,336",
         image_feature_size=576,
-        disable_image_processor=disable_image_processor,
     )
 
     prompt = "<image>" * 576 + (
         "\nUSER: What is the content of this image?\nASSISTANT:")
 
-    if disable_image_processor:
-        image = torch.load("images/stop_sign_pixel_values.pt")
-    else:
-        image = Image.open("images/stop_sign.jpg")
+    image = Image.open("images/stop_sign.jpg")
 
     outputs = llm.generate({
         "prompt": prompt,
-        "multi_modal_data": ImagePixelData(image),
+        "multi_modal_data": {
+            "image": image
+        },
     })
 
     for o in outputs:
@@ -40,45 +34,11 @@ def run_llava_pixel_values(*, disable_image_processor: bool = False):
         print(generated_text)
 
 
-def run_llava_image_features():
-    llm = LLM(
-        model="llava-hf/llava-1.5-7b-hf",
-        image_input_type="image_features",
-        image_token_id=32000,
-        image_input_shape="1,576,1024",
-        image_feature_size=576,
-    )
-
-    prompt = "<image>" * 576 + (
-        "\nUSER: What is the content of this image?\nASSISTANT:")
-
-    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")
-
-    outputs = llm.generate({
-        "prompt": prompt,
-        "multi_modal_data": ImageFeatureData(image),
-    })
-
-    for o in outputs:
-        generated_text = o.outputs[0].text
-        print(generated_text)
-
-
-def main(args):
-    if args.type == "pixel_values":
-        run_llava_pixel_values()
-    else:
-        run_llava_image_features()
+def main():
+    run_llava()
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Demo on Llava")
-    parser.add_argument("--type",
-                        type=str,
-                        choices=["pixel_values", "image_features"],
-                        default="pixel_values",
-                        help="image input type")
-    args = parser.parse_args()
     # Download from s3
     s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
     local_directory = "images"
@@ -95,4 +55,4 @@ if __name__ == "__main__":
         local_directory,
         "--no-sign-request",
     ])
-    main(args)
+    main()

+ 35 - 26
examples/vision/llava_next_example.py

@@ -4,33 +4,42 @@ import requests
 from PIL import Image
 
 from aphrodite import LLM, SamplingParams
-from aphrodite.multimodal.image import ImagePixelData
 
 # Dynamic image input is currently not supported and therefore
 # a fixed image input shape and its corresponding feature size is required.
 
-llm = LLM(
-    model="llava-hf/llava-v1.6-mistral-7b-hf",
-    image_input_type="pixel_values",
-    image_token_id=32000,
-    image_input_shape="1,3,336,336",
-    image_feature_size=1176,
-)
-
-prompt = "[INST] " + "<image>" * 1176 + "\nWhat is shown in this image? [/INST]"
-url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
-image = Image.open(BytesIO(requests.get(url).content))
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100)
-
-outputs = llm.generate(
-    {
-        "prompt": prompt,
-        "multi_modal_data": ImagePixelData(image),
-    },
-    sampling_params=sampling_params)
-
-generated_text = ""
-for o in outputs:
-    generated_text += o.outputs[0].text
-
-print(f"LLM output:{generated_text}")
+
+def run_llava_next():
+    llm = LLM(
+        model="llava-hf/llava-v1.6-mistral-7b-hf",
+        image_token_id=32000,
+        image_input_shape="1,3,336,336",
+        image_feature_size=1176,
+    )
+
+    prompt = "[INST] " + "<image>" * 1176 + (
+        "\nWhat is shown in this image? [/INST]")
+    url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
+    image = Image.open(BytesIO(requests.get(url).content))
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     max_tokens=100)
+
+    outputs = llm.generate(
+        {
+            "prompt": prompt,
+            "multi_modal_data": {
+                "image": image
+            }
+        },
+        sampling_params=sampling_params)
+
+    generated_text = ""
+    for o in outputs:
+        generated_text += o.outputs[0].text
+
+    print(f"LLM output:{generated_text}")
+
+
+if __name__ == "__main__":
+    run_llava_next()

+ 8 - 6
examples/vision/phi3v_example.py

@@ -4,22 +4,22 @@ import subprocess
 from PIL import Image
 
 from aphrodite import LLM, SamplingParams
-from aphrodite.multimodal.image import ImagePixelData
 
 
 def run_phi3v():
     model_path = "microsoft/Phi-3-vision-128k-instruct"
 
-    # Note: The model has 128k context length by default which may cause OOM
-    # In this example, we override max_model_len to 2048.
+    # Note: The default setting of max_num_seqs (256) and
+    # max_model_len (128k) for this model may cause OOM.
+    # In this example, we override max_num_seqs to 5 while
+    # keeping the original context length of 128k.
     llm = LLM(
         model=model_path,
         trust_remote_code=True,
-        image_input_type="pixel_values",
         image_token_id=32044,
         image_input_shape="1,3,1008,1344",
         image_feature_size=1921,
-        max_model_len=2048,
+        max_num_seqs=5,
     )
 
     image = Image.open("images/cherry_blossom.jpg")
@@ -33,7 +33,9 @@ def run_phi3v():
     outputs = llm.generate(
         {
             "prompt": prompt,
-            "multi_modal_data": ImagePixelData(image),
+            "multi_modal_data": {
+                "image": image
+            },
         },
         sampling_params=sampling_params)
     for o in outputs: