فهرست منبع

feat: dynamic image size support for VLMs

AlpinDale 6 ماه پیش
والد
کامیت
4599c98f99

+ 1 - 1
aphrodite/common/sequence.py

@@ -457,7 +457,7 @@ class SequenceGroup:
         return next(iter(self.seqs_dict.values())).prompt_token_ids
 
     @property
-    def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
+    def multi_modal_data(self) -> "MultiModalDataDict":
         # All sequences in the group should have the same multi-modal data.
         # We use the multi-modal data of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).multi_modal_data

+ 23 - 12
aphrodite/modeling/layers/rotary_embedding.py

@@ -485,7 +485,11 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
-class Phi3LongRoPERotaryEmbedding(nn.Module):
+class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
+    """Phi3 family of models scaled rotary embedding.
+
+    Based on the original RotaryEmbedding implementation.
+    """
 
     def __init__(
         self,
@@ -505,11 +509,13 @@ class Phi3LongRoPERotaryEmbedding(nn.Module):
 
         if rotary_dim != head_size:
             raise ValueError(
-                f"Rotary dim must be equal to head size, got {rotary_dim} "
-                f"and {head_size}")
+                f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
+                    rotary_dim != head_size ({rotary_dim}!={head_size}).")
         if is_neox_style is False:
             raise ValueError(
-                "Phi3SuScaledRotaryEmbedding only supports Neox style")
+                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
+            )
+
         self.head_size = head_size
         self.max_position_embeddings = max_position_embeddings
         self.original_max_position_embeddings = original_max_position_embeddings
@@ -536,8 +542,8 @@ class Phi3LongRoPERotaryEmbedding(nn.Module):
                              short_cache,
                              persistent=False)
 
-        long_cache = self._compute_cos_sin_cache(
-            original_max_position_embeddings, long_factor, long_mscale)
+        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
+                                                 long_factor, long_mscale)
         long_cache = long_cache.to(dtype)
         self.register_buffer("long_cos_sin_cache",
                              long_cache,
@@ -555,9 +561,12 @@ class Phi3LongRoPERotaryEmbedding(nn.Module):
             0, self.head_size, 2, dtype=torch.float) / self.head_size)))
         return inv_freq
 
-    def _compute_cos_sin_cache(self, max_position_embeddings: int,
-                               rescale_factors: List[float],
-                               mscale: float) -> torch.Tensor:
+    def _compute_cos_sin_cache(
+        self,
+        max_position_embeddings: int,
+        rescale_factors: List[float],
+        mscale: float,
+    ) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(rescale_factors)
         t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
@@ -575,15 +584,17 @@ class Phi3LongRoPERotaryEmbedding(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)
+
         k = self.original_max_position_embeddings
         long_prompt_offset = (torch.any(positions > k).float() *
                               torch.full_like(positions, k)).long()
         idx = (torch.add(positions, long_prompt_offset)
                if long_prompt_offset is not None else positions)
-        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
-            idx.device)
+        self.long_short_cos_sin_cache: torch.Tensor = (
+            self.long_short_cos_sin_cache.to(idx.device))
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
+
         cos, sin = cos_sin.chunk(2, dim=-1)
         cos = cos.repeat(1, 2).unsqueeze(-2)
         sin = sin.repeat(1, 2).unsqueeze(-2)
@@ -820,7 +831,7 @@ def get_rope(
                 for k, v in rope_scaling.items()
                 if k in ("short_mscale", "long_mscale")
             }
-            rotary_emb = Phi3LongRoPERotaryEmbedding(
+            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                 head_size, rotary_dim, max_position, original_max_position,
                 base, is_neox_style, dtype, short_factor, long_factor,
                 **extra_kwargs)

+ 37 - 0
aphrodite/modeling/models/clip.py

@@ -8,10 +8,14 @@ from PIL import Image
 from transformers import CLIPVisionConfig
 from transformers.models.clip.modeling_clip import CLIPAttention
 
+from aphrodite.common.config import ModelConfig
 from aphrodite.common.sequence import SequenceData
+from aphrodite.inputs import LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
+from aphrodite.multimodal.image import (cached_get_tokenizer,
+                                        repeat_and_pad_image_tokens)
 from aphrodite.quantization import QuantizationConfig
 
 
@@ -64,6 +68,39 @@ def dummy_image_for_clip(
     return {"image": image}
 
 
+def input_processor_for_clip(
+    model_config: ModelConfig,
+    hf_config: CLIPVisionConfig,
+    llm_inputs: LLMInputs,
+    *,
+    image_token_id: int,
+    image_feature_size_override: Optional[int] = None,
+):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    tokenizer = cached_get_tokenizer(model_config.tokenizer)
+
+    if image_feature_size_override is None:
+        image_feature_size = get_clip_image_feature_size(hf_config)
+    else:
+        image_feature_size = image_feature_size_override
+
+    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+        tokenizer,
+        llm_inputs.get("prompt"),
+        llm_inputs["prompt_token_ids"],
+        image_token_id=image_token_id,
+        repeat_count=image_feature_size,
+    )
+
+    # NOTE: Create a defensive copy of the original inputs
+    return LLMInputs(prompt_token_ids=new_token_ids,
+                     prompt=new_prompt,
+                     multi_modal_data=multi_modal_data)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
 class CLIPVisionEmbeddings(nn.Module):
 

+ 28 - 24
aphrodite/modeling/models/llava.py

@@ -7,7 +7,7 @@ from transformers import CLIPVisionConfig, LlavaConfig
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, VisionLanguageConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.inputs import INPUT_REGISTRY, InputContext
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
@@ -19,8 +19,10 @@ from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   input_processor_for_clip)
 from .interfaces import SupportsVision
+from .utils import merge_vision_embeddings
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -50,28 +52,10 @@ class LlavaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-def merge_vision_embeddings(input_ids: torch.Tensor,
-                            inputs_embeds: torch.Tensor,
-                            vision_embeddings: torch.Tensor,
-                            image_token_id: int) -> torch.Tensor:
-    """In place merges in vision_embeddings with inputs_embeds."""
-    mask = (input_ids == image_token_id)
-
-    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
-    if mask.sum() != image_feature_size:
-        raise ValueError(f"image_feature_size should be {image_feature_size}, "
-                         f"but found: {mask.sum()}")
-
-    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
-                                                 vision_embeddings.shape[-1])
-
-    return inputs_embeds
-
-
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
+    """Shape: `(batch_size, num_channels, height, width)`"""
 
 
 LlavaImageInputs = LlavaImagePixelInputs
@@ -89,15 +73,36 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
         )
 
         mm_data = dummy_image_for_clip(vision_config)
-
         return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
 
 
+def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    model_config = ctx.model_config
+    hf_config = ctx.get_hf_config(LlavaConfig)
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return input_processor_for_clip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+        )
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
 class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
     def __init__(self,
@@ -112,7 +117,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = CLIPVisionModel(config.vision_config)
-
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
@@ -138,7 +142,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
                 f"The expected image tensor shape is batch dimension plus "
                 f"{self.vlm_config.image_input_shape[1:]}. "
                 f"You supplied {data.shape}. "
-                f"If you are using Aphrodite's entrypoint, make sure your "
+                f"If you are using vLLM's entrypoint, make sure your "
                 f"supplied image input is consistent with "
                 f"image_input_shape in engine args.")
 

+ 110 - 79
aphrodite/modeling/models/llava_next.py

@@ -1,8 +1,7 @@
-from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import torch
 import torch.nn as nn
-from loguru import logger
 from PIL import Image
 from transformers import CLIPVisionConfig, LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
@@ -11,22 +10,23 @@ from typing_extensions import NotRequired
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, VisionLanguageConfig
-from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.inputs import INPUT_REGISTRY, InputContext
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.quantization.base_config import (QuantizationConfig)
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.clip import CLIPVisionModel
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   get_clip_patch_grid_length)
+                   get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsVision
-from .llava import LlavaMultiModalProjector, merge_vision_embeddings
+from .llava import LlavaMultiModalProjector
+from .utils import merge_vision_embeddings
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -36,16 +36,27 @@ _KEYS_TO_MODIFY_MAPPING = {
 
 class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
+    data: BatchedTensors
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+
+    Note that `num_patches` may be different for each batch.
+    """
 
     image_sizes: NotRequired[torch.Tensor]
-    """Shape: (batch_size, 2)"""
+    """
+    Shape: `(batch_size, 2)`
+
+    This should be in `(height, width)` format.
+    """
 
 
 LlavaNextImageInputs = LlavaNextImagePixelInputs
 
 
+# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
+# NOTE: new_height and new_width are further incremented to properly invert the
+# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
 def _get_llava_next_num_unpadded_features(
     height: int,
     width: int,
@@ -53,7 +64,6 @@ def _get_llava_next_num_unpadded_features(
     num_patch_height: int,
     num_patch_width: int,
 ) -> Tuple[int, int]:
-    # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
     current_height = npatches * num_patch_height
     current_width = npatches * num_patch_width
 
@@ -61,9 +71,13 @@ def _get_llava_next_num_unpadded_features(
     current_aspect_ratio: float = current_width / current_height
     if aspect_ratio > current_aspect_ratio:
         new_height = (height * current_width) // width
+        if new_height % 2 == 1:
+            new_height += 1
         current_height = new_height
     else:
         new_width = (width * current_height) // height
+        if new_width % 2 == 1:
+            new_width += 1
         current_width = new_width
 
     unpadded_features = current_height * current_width
@@ -71,7 +85,8 @@ def _get_llava_next_num_unpadded_features(
     return (unpadded_features, newline_features)
 
 
-def _get_llava_next_image_feature_size(
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
+def get_llava_next_image_feature_size(
     hf_config: LlavaNextConfig,
     *,
     input_height: int,
@@ -86,7 +101,9 @@ def _get_llava_next_image_feature_size(
         )
         base_feature_size = num_patches * num_patches
 
-        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+        # Note: We follow the "wrong" width/height order
+        # [ref: PR huggingface/transformers#31588]
+        num_patch_width, num_patch_height = get_anyres_image_grid_shape(
             image_size=(input_height, input_width),
             grid_pinpoints=hf_config.image_grid_pinpoints,
             patch_size=vision_config.image_size,
@@ -107,14 +124,16 @@ def _get_llava_next_image_feature_size(
 
 
 def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
-    multimodal_config = ctx.get_multimodal_config()
     hf_config = ctx.get_hf_config(LlavaNextConfig)
     vision_config = hf_config.vision_config
 
-    #TODO: change the logic for dummy data to support dynamic shape
-    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
-    image_feature_size = _get_llava_next_image_feature_size(
-        hf_config, input_height=dummy_height, input_width=dummy_width)
+    # Result in the max possible feature size (2x2 grid of 336x336px tiles)
+    dummy_height = dummy_width = 448
+    image_feature_size = get_llava_next_image_feature_size(
+        hf_config,
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
 
     if isinstance(vision_config, CLIPVisionConfig):
         seq_data = dummy_seq_data_for_clip(
@@ -136,27 +155,47 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     raise NotImplementedError(msg)
 
 
-def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
+def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
 
-    if isinstance(image, Image.Image):
+    model_config = ctx.model_config
+    hf_config = ctx.get_hf_config(LlavaNextConfig)
+    vision_config = hf_config.vision_config
 
-        # Temporary patch before dynamic number of image tokens is supported
-        _, _, h, w = ctx.get_multimodal_config().image_input_shape
-        if (w, h) != (image.width, image.height):
-            logger.warning(
-                "Dynamic image shape is currently not supported. "
-                "Resizing input image to (%d, %d).", w, h)
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        width, height = image_data.size
 
-            image = image.resize((w, h))
+        image_feature_size = get_llava_next_image_feature_size(
+            hf_config,
+            input_height=height,
+            input_width=width,
+        )
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
 
-        return MULTIMODAL_REGISTRY._get_plugin("image") \
-            ._default_input_mapper(ctx, image)
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return input_processor_for_clip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
 
-    raise TypeError(f"Invalid type for 'image': {type(image)}")
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
     def __init__(self,
@@ -169,8 +208,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.vlm_config = vlm_config
 
+        # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = CLIPVisionModel(config=config.vision_config)
-
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
@@ -193,24 +232,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
 
-    def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
-        _, num_channels, _, _ = self.vlm_config.image_input_shape
-
-        # Note that this is different from that of vLLM vision_language_config
-        # since the image is resized by the HuggingFace preprocessor
-        height = width = self.config.vision_config.image_size
-
-        if list(data.shape[2:]) != [num_channels, height, width]:
-            raise ValueError(
-                f"The expected image tensor shape is batch dimension plus "
-                f"num_patches plus {[num_channels, height, width]}. "
-                f"You supplied {data.shape}. "
-                f"If you are using vLLM's entrypoint, make sure your "
-                f"supplied image input is consistent with "
-                f"image_input_shape in engine args.")
-
-        return data
-
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
         if list(data.shape[1:]) != [2]:
             raise ValueError(
@@ -220,14 +241,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
+            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
 
-        if pixel_values is None or image_sizes is None:
+        if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
+        if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
                              f"Got type: {type(pixel_values)}")
 
@@ -237,7 +258,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         return LlavaNextImagePixelInputs(
             type="pixel_values",
-            data=self._validate_image_pixels(pixel_values),
+            data=pixel_values,
             image_sizes=self._validate_image_sizes(image_sizes),
         )
 
@@ -264,15 +285,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             strategy=self.config.vision_feature_select_strategy,
         )
 
+    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
     def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                       patch_embeddings: torch.Tensor, *,
                                       strategy: str) -> torch.Tensor:
-        # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
         if strategy == "flat":
             return patch_embeddings.flatten(0, 1)
 
         if strategy.startswith("spatial"):
-            orig_width, orig_height = image_size
             height = width = self.config.vision_config.image_size \
                 // self.config.vision_config.patch_size
 
@@ -286,13 +306,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
                 other_patch_embeds = patch_embeddings[1:]
 
                 # image_aspect_ratio == "anyres"
+                # Note: We follow the "wrong" width/height order
+                # [ref: PR huggingface/transformers#31588]
                 num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                    (orig_width, orig_height),
+                    image_size,
                     self.config.image_grid_pinpoints,
                     self.config.vision_config.image_size,
                 )
                 other_patch_embeds = other_patch_embeds \
-                    .view(num_patch_width, num_patch_height, height, width, -1)
+                    .view(num_patch_height, num_patch_width, height, width, -1)
 
                 if "unpad" in strategy:
                     other_patch_embeds = other_patch_embeds \
@@ -330,44 +352,53 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         raise ValueError(f"Unexpected patch merge strategy: {strategy}")
 
     def _process_image_pixels(
-            self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
+        self,
+        inputs: LlavaNextImagePixelInputs,
+    ) -> BatchedTensors:
         assert self.vision_tower is not None
 
         pixel_values = inputs["data"]
 
-        b, num_patches, c, h, w = pixel_values.shape
-        stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
+        if isinstance(pixel_values, torch.Tensor):
+            b, num_patches, c, h, w = pixel_values.shape
+            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
+            stacked_image_features = self._image_pixels_to_features(
+                self.vision_tower, stacked_pixel_values)
+            stacked_patch_embeddings = self.multi_modal_projector(
+                stacked_image_features)
 
+            return stacked_patch_embeddings.view(
+                b, num_patches, *stacked_patch_embeddings.shape[1:])
+
+        num_patches_per_batch = [v.shape[0] for v in pixel_values]
+        stacked_pixel_values = torch.cat(pixel_values)
         stacked_image_features = self._image_pixels_to_features(
             self.vision_tower, stacked_pixel_values)
 
-        return stacked_image_features.view(b, num_patches,
-                                           *stacked_image_features.shape[-2:])
+        return [
+            self.multi_modal_projector(image_features) for image_features in
+            torch.split(stacked_image_features, num_patches_per_batch)
+        ]
 
     def _process_image_input(
-            self, image_input: LlavaNextImageInputs) -> torch.Tensor:
-        assert self.vision_tower is not None
-        image_features = self._process_image_pixels(image_input)
-
-        patch_embeddings = self.multi_modal_projector(image_features)
+            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
+        patch_embeddings = self._process_image_pixels(image_input)
 
         image_sizes = image_input.get("image_sizes")
         if image_sizes is None:
-            batch_size = image_input["data"].shape[0]
+            batch_size = len(image_input["data"])
             vision_config = self.config.vision_config
-            default_width = default_height = vision_config.image_size
-            image_sizes = torch.as_tensor([[default_width, default_height]
+            default_height = default_width = vision_config.image_size
+            image_sizes = torch.as_tensor([[default_height, default_width]
                                            for _ in range(batch_size)])
 
-        merged_patch_embeddings = [
+        return [
             self._merge_image_patch_embeddings(image_sizes[i],
-                                               patch_features,
+                                               patch_features_batch,
                                                strategy="spatial_unpad")
-            for i, patch_features in enumerate(patch_embeddings)
+            for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-        return torch.stack(merged_patch_embeddings, dim=0)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -401,8 +432,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each grid patch for each input image.
-                Expects a batch with shape `[1, num_patches, 3, 336, 336]`.
-            image_sizes: The original `(width, height)` for each input image.
+                Expects a batch with shape `[1, num_patches, 3, h, w]`.
+            image_sizes: The original `(height, width)` for each input image.
                 Expects a batch with shape `[1, 2]`.
 
         See also:

+ 158 - 64
aphrodite/modeling/models/phi3v.py

@@ -14,7 +14,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
+import re
+from functools import lru_cache
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import numpy as np
 import torch
@@ -24,9 +26,10 @@ from PIL import Image
 from transformers import CLIPVisionConfig, PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
-from aphrodite.common.config import CacheConfig, VisionLanguageConfig
+from aphrodite.common.config import (CacheConfig, ModelConfig,
+                                     VisionLanguageConfig)
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.inputs import INPUT_REGISTRY, InputContext
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
@@ -34,10 +37,12 @@ from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.clip import CLIPVisionModel
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.multimodal import MULTIMODAL_REGISTRY
+from aphrodite.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from aphrodite.multimodal.image import cached_get_tokenizer
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   input_processor_for_clip)
 from .interfaces import SupportsVision
 
 _KEYS_TO_MODIFY_MAPPING = {
@@ -249,50 +254,22 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
 
 class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
-
-    image_sizes: torch.Tensor
-    """Shape: (batch_size, 2)"""
-
-
-def _get_phi3v_image_feature_size(
-    *,
-    input_height: int,
-    input_width: int,
-) -> int:
-    h, w = input_height, input_width
-
-    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
-    return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
+    data: BatchedTensors
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
 
+    Note that `num_patches` may be different for each batch.
+    """
 
-def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
-    multimodal_config = ctx.get_multimodal_config()
-
-    #TODO: change the logic for dummy data to support dynamic shape
-    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
-    image_feature_size = _get_phi3v_image_feature_size(
-        input_height=dummy_height,
-        input_width=dummy_width,
-    )
-
-    seq_data = dummy_seq_data_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        seq_len,
-        image_token_id=32044,
-        image_feature_size_override=image_feature_size,
-    )
-    mm_data = dummy_image_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        image_width_override=dummy_width,
-        image_height_override=dummy_height,
-    )
+    image_sizes: torch.Tensor
+    """
+    Shape: `(batch_size, 2)`
 
-    return seq_data, mm_data
+    This should be in `(height, width)` format.
+    """
 
 
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
     target_height = int(np.ceil(height / padding_unit) * padding_unit)
     top_padding = int((target_height - height) / 2)
@@ -302,7 +279,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
     return padded_width, padded_height
 
 
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
 def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
     transposed = False
     if width < height:
@@ -327,26 +304,133 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
     return padded_width, padded_height
 
 
-def _image_processor(ctx: InputContext,
-                     image: object) -> Dict[str, torch.Tensor]:
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
+def get_phi3v_image_feature_size(
+    hf_config: PretrainedConfig,
+    *,
+    input_height: int,
+    input_width: int,
+) -> int:
+    num_crops = getattr(hf_config, "num_crops", 16)
+    new_width, new_height = _calc_hd_transform_size(width=input_width,
+                                                    height=input_height,
+                                                    hd_num=num_crops)
 
-    if isinstance(image, Image.Image):
-        # Temporary patch before dynamic number of image tokens is supported
-        _, _, h, w = ctx.get_multimodal_config().image_input_shape
-        if (w, h) != _calc_hd_transform_size(width=image.width,
-                                             height=image.height):
-            logger.warning("Dynamic image shape is currently not supported. "
-                           f"Resizing input image to ({w}, {h}).")
+    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
+        + (new_height // 336 + 1) * 12
 
-            image = image.resize((w, h))
 
-        return MULTIMODAL_REGISTRY._get_plugin("image") \
-                ._default_input_mapper(ctx, image)
-    raise TypeError(f"Invalid type for 'image': {type(image)}")
+def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
+    # Result in the max possible feature size (h:w = 16:1)
+    dummy_height, dummy_width = 8000, 50
+    image_feature_size = get_phi3v_image_feature_size(
+        ctx.get_hf_config(PretrainedConfig),
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
+
+    seq_data = dummy_seq_data_for_clip(
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        seq_len,
+        image_token_id=32044,
+        image_feature_size_override=image_feature_size,
+    )
+    mm_data = dummy_image_for_clip(
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        image_width_override=dummy_width,
+        image_height_override=dummy_height,
+    )
+
+    return seq_data, mm_data
+
 
+# Reserve this function to also handle placeholders for additional images
+# [ref: PR #5820]
+@lru_cache
+def _get_image_placeholder_token_ids(model_config: ModelConfig,
+                                     idx: int) -> List[int]:
+    assert idx > 0
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
+    tokenizer = cached_get_tokenizer(model_config.tokenizer)
+
+    # We need to get the token for "<", not "▁<"
+    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
+    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
+    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
+        f"a<|image_{idx}|>", add_special_tokens=False)
+    assert a_token_id == a_token_id_
+
+    return image_placeholder_token_ids
+
+
+def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    model_config = ctx.model_config
+    multimodal_config = ctx.get_multimodal_config()
+    hf_config = ctx.get_hf_config(PretrainedConfig)
+
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        w, h = image_data.size
+        w, h = _calc_hd_transform_size(width=w, height=h)
+
+        image_feature_size = get_phi3v_image_feature_size(hf_config,
+                                                          input_width=w,
+                                                          input_height=h)
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
+    prompt = llm_inputs.get("prompt")
+    if prompt is None:
+        new_prompt = None
+    else:
+        if prompt.count("<|image|>") > 0:
+            logger.warning("Please follow the prompt format that is "
+                           "documented on HuggingFace which does not involve "
+                           "repeating <|image|> tokens.")
+        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
+            logger.warning("Multiple image input is not supported yet, "
+                           "so any extra image tokens will be treated "
+                           "as plain text.")
+
+        new_prompt = prompt
+
+    prompt_token_ids = llm_inputs["prompt_token_ids"]
+    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
+
+    new_token_ids: List[int] = []
+    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
+        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
+            new_token_ids.append(multimodal_config.image_token_id)
+
+            # No need to further scan the list since we only replace once
+            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
+            break
+        else:
+            new_token_ids.append(prompt_token_ids[i])
+
+    # NOTE: Create a defensive copy of the original inputs
+    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
+                           prompt=new_prompt,
+                           multi_modal_data=multi_modal_data)
+
+    return input_processor_for_clip(
+        model_config,
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        llm_inputs,
+        image_token_id=multimodal_config.image_token_id,
+        image_feature_size_override=image_feature_size,
+    )
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
 class Phi3VForCausalLM(nn.Module, SupportsVision):
 
     def __init__(self,
@@ -360,6 +444,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         self.vlm_config = vlm_config
 
         self.model = LlamaModel(config, cache_config, quant_config)
+
+        # TODO: Optionally initializes this for supporting embeddings.
         self.vision_embed_tokens = Phi3HDImageEmbedding(
             vlm_config, config, self.model.embed_tokens)
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -373,12 +459,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
 
-        if pixel_values is not None and image_sizes is not None:
-            return Phi3VImagePixelInputs(type="pixel_values",
-                                         data=pixel_values,
-                                         image_sizes=image_sizes)
+        if pixel_values is None:
+            return None
+
+        if not isinstance(pixel_values, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
+
+        if not isinstance(image_sizes, torch.Tensor):
+            raise ValueError("Incorrect type of image sizes. "
+                             f"Got type: {type(image_sizes)}")
 
-        return None
+        return Phi3VImagePixelInputs(type="pixel_values",
+                                     data=pixel_values,
+                                     image_sizes=image_sizes)
 
     def forward(self,
                 input_ids: torch.Tensor,

+ 41 - 0
aphrodite/modeling/models/utils.py

@@ -0,0 +1,41 @@
+import torch
+
+from aphrodite.multimodal import BatchedTensors
+
+
+def merge_vision_embeddings(input_ids: torch.Tensor,
+                            inputs_embeds: torch.Tensor,
+                            vision_embeddings: BatchedTensors,
+                            image_token_id: int) -> torch.Tensor:
+    """
+    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
+    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
+
+    Note:
+        This updates `inputs_embeds` in place.
+    """
+    mask = (input_ids == image_token_id)
+    num_expected_tokens = mask.sum()
+
+    if isinstance(vision_embeddings, torch.Tensor):
+        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
+        total_tokens = batch_size * batch_tokens
+        if num_expected_tokens != total_tokens:
+            expr = f"{batch_size} x {batch_tokens}"
+            raise ValueError(
+                f"Attempted to assign {expr} = {total_tokens} "
+                f"image tokens to {num_expected_tokens} placeholders")
+
+        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
+    else:
+        size_per_batch = [t.shape[0] for t in vision_embeddings]
+        total_tokens = sum(size_per_batch)
+        if num_expected_tokens != total_tokens:
+            expr = ' + '.join(map(str, size_per_batch))
+            raise ValueError(
+                f"Attempted to assign {expr} = {total_tokens} "
+                f"image tokens to {num_expected_tokens} placeholders")
+
+        inputs_embeds[mask] = torch.cat(vision_embeddings)
+
+    return inputs_embeds

+ 5 - 2
aphrodite/multimodal/__init__.py

@@ -1,4 +1,5 @@
-from .base import MultiModalDataDict, MultiModalPlugin
+from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs,
+                   MultiModalPlugin)
 from .registry import MultiModalRegistry
 
 MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -10,8 +11,10 @@ See also:
 """
 
 __all__ = [
+    "BatchedTensors",
+    "MultiModalDataDict",
+    "MultiModalInputs",
     "MultiModalPlugin",
     "MULTIMODAL_REGISTRY",
     "MultiModalRegistry",
-    "MultiModalDataDict",
 ]

+ 83 - 15
aphrodite/multimodal/base.py

@@ -1,22 +1,88 @@
+import sys
 from abc import ABC, abstractmethod
-from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type,
-                    TypedDict, TypeVar, Union)
+from collections import UserDict, defaultdict
+from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
+                    TypeVar, Union)
 
+import torch
+import torch.types
 from loguru import logger
+from PIL import Image
+from torch import nn
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.inputs import InputContext
 
-if TYPE_CHECKING:
-    import torch
-    from PIL import Image
-    from torch import nn
+BatchedTensors = Union[torch.Tensor, List[torch.Tensor]]
+"""
+If each input tensor in the batch has the same size, this is a single batched
+tensor; otherwise, this is a list of tensors with one element per batch.
+"""
+
+if sys.version_info < (3, 9):
+    # UserDict cannot be subscripted
+    class _MultiModalInputsBase(UserDict):
+        pass
+else:
+
+    class _MultiModalInputsBase(UserDict[str, torch.Tensor]):
+        pass
+
+
+class MultiModalInputs(_MultiModalInputsBase):
+    """
+    A dictionary that represents the keyword arguments to
+    :meth:`~torch.nn.Module.forward`.
+    """
+
+    @staticmethod
+    def try_concat(
+        tensors: List[torch.Tensor],
+        *,
+        device: torch.types.Device,
+    ) -> BatchedTensors:
+        # Avoid initializing CUDA too early
+        import torch
+
+        unbatched_shape = tensors[0].shape[1:]
 
-N = TypeVar("N", bound=Type["nn.Module"])
+        for tensor in tensors:
+            if tensor.shape[1:] != unbatched_shape:
+                return [
+                    tensor.squeeze(0).to(device=device) for tensor in tensors
+                ]
+
+        return torch.cat(tensors, dim=0).to(device=device)
+
+    @staticmethod
+    def batch(
+        inputs_list: List["MultiModalInputs"],
+        device: torch.types.Device,
+    ) -> Dict[str, BatchedTensors]:
+        """Batch multiple inputs together into a dictionary."""
+        if len(inputs_list) == 0:
+            return {}
+
+        keys = inputs_list[0].keys()
+
+        item_lists: Dict[str, List[torch.Tensor]] = defaultdict(list)
+
+        for inputs in inputs_list:
+            if inputs.keys() != keys:
+                msg = f"Inputs do not share the same keys ({keys})"
+                raise ValueError(msg)
+
+            for k, v in inputs.items():
+                item_lists[k].append(v)
+
+        return {
+            k: MultiModalInputs.try_concat(item_list, device=device)
+            for k, item_list in item_lists.items()
+        }
 
 
 class MultiModalDataBuiltins(TypedDict, total=False):
-    image: "Image.Image"
+    image: Image.Image
 
 
 MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
@@ -28,12 +94,13 @@ to the model by the corresponding mapper. By default, the mapper of
 the corresponding plugin with the same modality key is applied.
 """
 
-MultiModalInputMapper = Callable[[InputContext, object], Dict[str,
-                                                              "torch.Tensor"]]
+MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
 """Return a dictionary to be passed as keyword arguments to
 :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
 and processors in HuggingFace Transformers."""
 
+N = TypeVar("N", bound=Type[nn.Module])
+
 
 class MultiModalPlugin(ABC):
     """
@@ -47,8 +114,7 @@ class MultiModalPlugin(ABC):
     """
 
     def __init__(self) -> None:
-        self._input_mappers: Dict[Type["nn.Module"],
-                                  MultiModalInputMapper] = {}
+        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
 
     @abstractmethod
     def get_data_key(self) -> str:
@@ -59,7 +125,7 @@ class MultiModalPlugin(ABC):
 
     @abstractmethod
     def _default_input_mapper(self, ctx: InputContext,
-                              data: object) -> Dict[str, "torch.Tensor"]:
+                              data: object) -> MultiModalInputs:
         """Return a dictionary to be passed as keyword arguments to
         :meth:`~torch.nn.Module.forward`. This is similar in concept to
         tokenizers and processors in HuggingFace Transformers.
@@ -79,6 +145,7 @@ class MultiModalPlugin(ABC):
 
         See also:
             :ref:`input_processing_pipeline`
+            :ref:`adding_a_new_multimodal_model`
         """
 
         def wrapper(model_cls: N) -> N:
@@ -96,7 +163,7 @@ class MultiModalPlugin(ABC):
         return wrapper
 
     def map_input(self, model_config: ModelConfig,
-                  data: object) -> Dict[str, "torch.Tensor"]:
+                  data: object) -> MultiModalInputs:
         """
         Apply an input mapper to a data passed
         to the model, transforming the data into a dictionary of model inputs.
@@ -105,7 +172,8 @@ class MultiModalPlugin(ABC):
 
         The model is identified by ``model_config``.
 
-        TODO: Add guide [ref: PR #5276]
+        See also:
+            :ref:`adding_a_new_multimodal_model`
         """
         # Avoid circular import
         from aphrodite.modeling.model_loader import get_model_architecture

+ 94 - 6
aphrodite/multimodal/image.py

@@ -1,17 +1,100 @@
 from functools import lru_cache
-from typing import Dict
+from typing import List, Optional, Tuple, TypeVar
 
 import torch
 from loguru import logger
 from PIL import Image
+from transformers import PreTrainedTokenizerBase
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.inputs.registry import InputContext
 from aphrodite.transformers_utils.image_processor import get_image_processor
+from aphrodite.transformers_utils.tokenizer import get_tokenizer
 
-from .base import MultiModalPlugin
+from .base import MultiModalInputs, MultiModalPlugin
 
 cached_get_image_processor = lru_cache(get_image_processor)
+cached_get_tokenizer = lru_cache(get_tokenizer)
+
+# Utilities for image input processors
+_T = TypeVar("_T", str, int)
+
+
+def repeat_and_pad_token(
+    token: _T,
+    *,
+    repeat_count: int = 1,
+    pad_token_left: Optional[_T] = None,
+    pad_token_right: Optional[_T] = None,
+) -> List[_T]:
+    replacement = [token] * repeat_count
+    if pad_token_left is not None:
+        replacement = [pad_token_left] + replacement
+    if pad_token_right is not None:
+        replacement = replacement + [pad_token_right]
+
+    return replacement
+
+
+def repeat_and_pad_image_tokens(
+    tokenizer: PreTrainedTokenizerBase,
+    prompt: Optional[str],
+    prompt_token_ids: List[int],
+    *,
+    image_token_id: int,
+    repeat_count: int = 1,
+    pad_token_left: Optional[int] = None,
+    pad_token_right: Optional[int] = None,
+) -> Tuple[Optional[str], List[int]]:
+    if prompt is None:
+        new_prompt = None
+    else:
+        image_token_str = tokenizer.decode(image_token_id)
+        pad_token_str_left = (None if pad_token_left is None else
+                              tokenizer.decode(pad_token_left))
+        pad_token_str_right = (None if pad_token_right is None else
+                               tokenizer.decode(pad_token_right))
+        replacement_str = "".join(
+            repeat_and_pad_token(
+                image_token_str,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_str_left,
+                pad_token_right=pad_token_str_right,
+            ))
+
+        image_token_count = prompt.count(image_token_str)
+        # This is an arbitrary number to distinguish between the two cases
+        if image_token_count > 16:
+            logger.warning(
+                "Please follow the prompt format that is "
+                "documented on HuggingFace which does not involve "
+                "repeating %s tokens.", image_token_str)
+        elif image_token_count > 1:
+            logger.warning("Multiple image input is not supported yet, "
+                           "so any extra image tokens will be treated "
+                           "as plain text.")
+
+        # The image tokens are removed to be consistent with HuggingFace
+        new_prompt = prompt.replace(image_token_str, replacement_str, 1)
+
+    new_token_ids: List[int] = []
+    for i, token in enumerate(prompt_token_ids):
+        if token == image_token_id:
+            replacement_ids = repeat_and_pad_token(
+                image_token_id,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_left,
+                pad_token_right=pad_token_right,
+            )
+            new_token_ids.extend(replacement_ids)
+
+            # No need to further scan the list since we only replace once
+            new_token_ids.extend(prompt_token_ids[i + 1:])
+            break
+        else:
+            new_token_ids.append(token)
+
+    return new_prompt, new_token_ids
 
 
 class ImagePlugin(MultiModalPlugin):
@@ -25,7 +108,7 @@ class ImagePlugin(MultiModalPlugin):
             trust_remote_code=model_config.trust_remote_code)
 
     def _default_input_mapper(self, ctx: InputContext,
-                              data: object) -> Dict[str, torch.Tensor]:
+                              data: object) -> MultiModalInputs:
         model_config = ctx.model_config
         if isinstance(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
@@ -33,10 +116,15 @@ class ImagePlugin(MultiModalPlugin):
                 raise RuntimeError("No HuggingFace processor is available"
                                    "to process the image object")
             try:
-                return image_processor.preprocess(data, return_tensors="pt") \
-                    .to(model_config.dtype).data
+                batch_data = image_processor \
+                    .preprocess(data, return_tensors="pt") \
+                    .data
             except Exception:
                 logger.error("Failed to process image (%s)", data)
                 raise
 
-        raise TypeError(f"Invalid type for 'image': {type(data)}")
+            return MultiModalInputs(batch_data)
+        elif isinstance(data, torch.Tensor):
+            raise NotImplementedError("Embeddings input is not supported yet")
+
+        raise TypeError(f"Invalid image type: {type(data)}")

+ 22 - 11
aphrodite/multimodal/registry.py

@@ -1,16 +1,15 @@
 import functools
-from typing import Optional, Sequence, Type, TypeVar
+from typing import Dict, Optional, Sequence
 
+import torch
 from loguru import logger
-from torch import nn
 
 from aphrodite.common.config import ModelConfig
 
-from .base import MultiModalDataDict, MultiModalInputMapper, MultiModalPlugin
+from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
+                   MultiModalPlugin)
 from .image import ImagePlugin
 
-N = TypeVar("N", bound=Type[nn.Module])
-
 
 class MultiModalRegistry:
     """
@@ -59,7 +58,7 @@ class MultiModalRegistry:
         return self.register_input_mapper("image", mapper)
 
     def _process_input(self, key: str, value: object,
-                       model_config: ModelConfig):
+                       model_config: ModelConfig) -> MultiModalInputs:
         plugin = self._plugins.get(key)
         if plugin:
             return plugin.map_input(model_config, value)
@@ -91,16 +90,28 @@ class MultiModalRegistry:
         """
         return self.register_input_mapper("image", mapper)
 
-    def map_input(self, model_config: ModelConfig, data: MultiModalDataDict):
+    def map_input(self, model_config: ModelConfig,
+                  data: MultiModalDataDict) -> MultiModalInputs:
         """
         Apply an input mapper to the data passed to the model.
         
         See :meth:`MultiModalPlugin.map_input` for more details.
         """
-        result_list = [
-            self._process_input(k, v, model_config) for k, v in data.items()
-        ]
-        return {k: v for d in result_list for k, v in d.items()}
+        merged_dict: Dict[str, torch.Tensor] = {}
+
+        for data_key, data_value in data.items():
+            input_dict = self._process_input(data_key, data_value,
+                                             model_config)
+
+            for input_key, input_tensor in input_dict.items():
+                if input_key in merged_dict:
+                    raise ValueError(f"The input mappers (keys={set(data)}) "
+                                     f"resulted in a conflicting keyword "
+                                     f"argument to `forward()`: {input_key}")
+
+                merged_dict[input_key] = input_tensor
+
+        return MultiModalInputs(merged_dict)
 
     def create_input_mapper(self, model_config: ModelConfig):
         """

+ 61 - 33
aphrodite/multimodal/utils.py

@@ -5,15 +5,60 @@ from typing import Optional, Union
 from urllib.parse import urlparse
 
 import aiohttp
+import requests
 from PIL import Image
 
-from aphrodite.common.config import ModelConfig
 from aphrodite.multimodal.base import MultiModalDataDict
+from aphrodite.version import __version__ as APHRODITE_VERSION
 
 APHRODITE_IMAGE_FETCH_TIMEOUT = int(
     os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 10))
 
 
+def _validate_remote_url(url: str, *, name: str):
+    parsed_url = urlparse(url)
+    if parsed_url.scheme not in ["http", "https"]:
+        raise ValueError(f"Invalid '{name}': A valid '{name}' "
+                         "must have scheme 'http' or 'https'.")
+
+
+def _get_request_headers():
+    return {"User-Agent": f"aphrodite/{APHRODITE_VERSION}"}
+
+
+def _load_image_from_bytes(b: bytes):
+    image = Image.open(BytesIO(b))
+    image.load()
+    return image
+
+
+def _load_image_from_data_url(image_url: str):
+    # Only split once and assume the second part is the base64 encoded image
+    _, image_base64 = image_url.split(",", 1)
+    return load_image_from_base64(image_base64)
+
+
+def fetch_image(image_url: str) -> Image.Image:
+    """Load PIL image from a url or base64 encoded openai GPT4V format"""
+    if image_url.startswith('http'):
+        _validate_remote_url(image_url, name="image_url")
+
+        headers = _get_request_headers()
+
+        with requests.get(url=image_url, headers=headers) as response:
+            response.raise_for_status()
+            image_raw = response.content
+        image = _load_image_from_bytes(image_raw)
+
+    elif image_url.startswith('data:image'):
+        image = _load_image_from_data_url(image_url)
+    else:
+        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
+                         "with either 'data:image' or 'http'.")
+
+    return image
+
+
 class ImageFetchAiohttp:
     aiohttp_client: Optional[aiohttp.ClientSession] = None
 
@@ -33,34 +78,32 @@ class ImageFetchAiohttp:
         """Load PIL image from a url or base64 encoded openai GPT4V format"""
 
         if image_url.startswith('http'):
-            parsed_url = urlparse(image_url)
-            if parsed_url.scheme not in ["http", "https"]:
-                raise ValueError("Invalid 'image_url': A valid 'image_url' "
-                                 "must have scheme 'http' or 'https'.")
-            # Avoid circular import
-            from aphrodite import __version__ as APHRODITE_VERSION
+            _validate_remote_url(image_url, name="image_url")
 
             client = cls.get_aiohttp_client()
-            headers = {"User-Agent": f"aphrodite/{APHRODITE_VERSION}"}
+            headers = _get_request_headers()
 
             async with client.get(url=image_url, headers=headers) as response:
                 response.raise_for_status()
                 image_raw = await response.read()
-            image = Image.open(BytesIO(image_raw))
+            image = _load_image_from_bytes(image_raw)
 
         # Only split once and assume the second part is the base64 encoded image
         elif image_url.startswith('data:image'):
-            image = load_image_from_base64(image_url.split(',', 1)[1])
+            image = _load_image_from_data_url(image_url)
 
         else:
             raise ValueError(
                 "Invalid 'image_url': A valid 'image_url' must start "
                 "with either 'data:image' or 'http'.")
-
-        image.load()
         return image
 
 
+async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
+    image = await ImageFetchAiohttp.fetch_image(image_url)
+    return {"image": image}
+
+
 def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
     """Encode a pillow image to base64 format."""
 
@@ -73,26 +116,11 @@ def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
 
 def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
     """Load image from base64 format."""
-    return Image.open(BytesIO(base64.b64decode(image)))
+    return _load_image_from_bytes(base64.b64decode(image))
 
 
-# TODO: move this to a model registry for preprocessing vision
-# language prompts based on the model type.
-def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
-                               config: ModelConfig) -> str:
-    """Combine image and text prompts for vision language model depending on
-    the model architecture."""
-
-    if config.hf_config.model_type in ("llava", "llava_next"):
-        full_prompt = f"{image_prompt}\n{text_prompt}"
-    elif config.hf_config.model_type == 'phi3_v':
-        full_prompt = f"{image_prompt}<s>\n{text_prompt}"
-    else:
-        raise ValueError(
-            f"Unsupported model type: {config.hf_config.model_type}")
-    return full_prompt
-
-
-async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
-    image = await ImageFetchAiohttp.fetch_image(image_url)
-    return {"image": image}
+def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
+    """Rescale the dimensions of an image by a constant factor."""
+    new_width = int(image.width * size_factor)
+    new_height = int(image.height * size_factor)
+    return image.resize((new_width, new_height))

+ 15 - 19
aphrodite/task_handler/cpu_model_runner.py

@@ -1,6 +1,6 @@
-from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Type, Union)
 
 import torch
 from torch import nn
@@ -14,7 +14,8 @@ from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
 from aphrodite.common.utils import make_tensor_with_pad
 from aphrodite.modeling import SamplingMetadata
 from aphrodite.modeling.model_loader import get_model
-from aphrodite.multimodal import MULTIMODAL_REGISTRY
+from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                                  MultiModalInputs)
 from aphrodite.task_handler.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
@@ -37,7 +38,7 @@ class CPUModelInput(ModelRunnerInputBase):
     input_positions: Optional[torch.Tensor] = None
     attn_metadata: Optional["AttentionMetadata"] = None
     sampling_metadata: Optional["SamplingMetadata"] = None
-    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -129,15 +130,14 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
-            str, torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
-        multi_modal_kwargs_list: Dict[str,
-                                      List[torch.Tensor]] = defaultdict(list)
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -159,10 +159,9 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
             input_positions.extend(list(range(computed_len, seq_len)))
 
             mm_data = seq_group_metadata.multi_modal_data
-            if mm_data is not None:
+            if mm_data:
                 mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                for k, v in mm_kwargs.items():
-                    multi_modal_kwargs_list[k].append(v)
+                multi_modal_inputs_list.append(mm_kwargs)
 
             # Compute the slot mapping.
             block_table = seq_group_metadata.block_tables[seq_id]
@@ -186,11 +185,6 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-        multi_modal_kwargs = {
-            k: torch.cat(v, dim=0).to(self.device)
-            for k, v in multi_modal_kwargs_list.items()
-        }
-
         num_prompt_tokens = len(input_tokens)
 
         input_tokens = torch.tensor(input_tokens,
@@ -214,6 +208,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
             block_tables=torch.tensor([]),
             slot_mapping=slot_mapping,
         )
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
         return (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_kwargs)
 
@@ -363,10 +361,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
             "positions": model_input.input_positions,
             "kv_caches": kv_caches,
             "attn_metadata": model_input.attn_metadata,
+            **(model_input.multi_modal_kwargs or {}),
         }
-        if (self.vision_language_config
-                and model_input.multi_modal_kwargs is not None):
-            execute_model_kwargs.update(model_input.multi_modal_kwargs)
 
         hidden_states = model_executable(**execute_model_kwargs)
 

+ 1 - 3
aphrodite/task_handler/embedding_model_runner.py

@@ -90,10 +90,8 @@ class EmbeddingModelRunner(
             "positions": model_input.input_positions,
             "kv_caches": kv_caches,
             "attn_metadata": model_input.attn_metadata,
+            **(model_input.multi_modal_kwargs or {}),
         }
-        if self.vision_language_config:
-            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-            execute_model_kwargs.update({"image_input": multi_modal_kwargs})
         hidden_states = model_executable(**execute_model_kwargs)
 
         # Only perform pooling in the driver worker.

+ 14 - 13
aphrodite/task_handler/model_runner.py

@@ -3,8 +3,8 @@ import gc
 import time
 import warnings
 from collections import defaultdict
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
-                    TypeVar, Union)
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
+                    Tuple, Type, TypeVar, Union)
 
 import numpy as np
 import torch
@@ -46,7 +46,8 @@ from aphrodite.modeling import SamplingMetadata
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.modeling.models.interfaces import supports_lora
-from aphrodite.multimodal import MULTIMODAL_REGISTRY
+from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                                  MultiModalInputs)
 from aphrodite.task_handler.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
@@ -85,7 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     lora_mapping: Optional["LoRAMapping"] = None
     lora_requests: Optional[Set[LoRARequest]] = None
     attn_metadata: Optional["AttentionMetadata"] = None
-    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
     request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
     finished_requests_ids: Optional[List[str]] = None
     virtual_engine: int = 0
@@ -374,8 +375,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         context_lens: List[int] = []
         query_lens: List[int] = []
         block_tables: List[List[int]] = []
-        multi_modal_kwargs_list: Dict[str,
-                                      List[torch.Tensor]] = defaultdict(list)
+        multi_modal_inputs_list: List[MultiModalInputs] = []
         request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
         decode_only = True
         num_prefills = 0
@@ -546,8 +546,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 if mm_data:
                     # Process multi-modal data
                     mm_kwargs = self.multi_modal_input_mapper(mm_data)
-                    for k, v in mm_kwargs.items():
-                        multi_modal_kwargs_list[k].append(v)
+                    multi_modal_inputs_list.append(mm_kwargs)
 
                 is_profile_run = _is_block_tables_empty(
                     seq_group_metadata.block_tables)
@@ -764,10 +763,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         else:
             lora_mapping = None
 
-        multi_modal_kwargs = {
-            k: torch.cat(v, dim=0).to(self.device)
-            for k, v in multi_modal_kwargs_list.items()
-        }
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
         request_ids_to_seq_ids = {
             seq_group_metadata.request_id:
             list(seq_group_metadata.seq_data.keys())
@@ -841,7 +838,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
             seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                 .dummy_data_for_profiling(model_config, seq_len)
-            assert len(seq_data.prompt_token_ids) == seq_len
+
+            # Having more tokens is over-conservative but otherwise fine
+            assert len(seq_data.prompt_token_ids) >= seq_len, (
+                f"Expected at least {seq_len} dummy tokens for profiling, "
+                f"but got: {len(seq_data.prompt_token_ids)}")
 
             seq = SequenceGroupMetadata(
                 request_id=str(group_id),

+ 29 - 6
aphrodite/task_handler/neuron_model_runner.py

@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Union)
 
 import torch
 from loguru import logger
@@ -13,6 +14,8 @@ from aphrodite.common.utils import (is_pin_memory_available,
                                     make_tensor_with_pad)
 from aphrodite.modeling import SamplingMetadata
 from aphrodite.modeling.model_loader.neuron import get_neuron_model
+from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                                  MultiModalInputs)
 from aphrodite.task_handler.model_runner_base import (ModelRunnerBase,
                                                       ModelRunnerInputBase)
 
@@ -29,6 +32,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
     input_positions: Optional[torch.Tensor] = None
     input_block_ids: Optional[torch.Tensor] = None
     sampling_metadata: Optional["SamplingMetadata"] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -65,6 +69,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
 
@@ -76,13 +84,15 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
+            str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         input_block_ids: List[int] = []
 
         seq_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -102,6 +112,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             assert len(block_table) == 1
             input_block_ids.append(block_table[0])
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                # Process multi-modal data
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         max_seq_len = max(seq_lens)
         assert max_seq_len > 0
         input_tokens = make_tensor_with_pad(input_tokens,
@@ -118,7 +134,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
                                        dtype=torch.long,
                                        device=self.device)
 
-        return input_tokens, input_positions, input_block_ids, seq_lens
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, input_block_ids, seq_lens,
+                multi_modal_kwargs)
 
     def _prepare_decode(
         self,
@@ -184,8 +204,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         is_prompt = seq_group_metadata_list[0].is_prompt
         # Prepare input tensors.
         if is_prompt:
-            (input_tokens, input_positions, input_block_ids,
-             seq_lens) = self._prepare_prompt(seq_group_metadata_list)
+            (input_tokens, input_positions, input_block_ids, seq_lens,
+             multi_modal_kwargs
+             ) = self._prepare_prompt(seq_group_metadata_list)
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
@@ -203,7 +224,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
-                                   sampling_metadata=sampling_metadata)
+                                   sampling_metadata=sampling_metadata,
+                                   multi_modal_kwargs=multi_modal_kwargs)
 
     @torch.inference_mode()
     def execute_model(
@@ -221,6 +243,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             input_block_ids=model_input.input_block_ids,
+            **(model_input.multi_modal_kwargs or {}),
         )
 
         # Compute the logits.

+ 25 - 12
aphrodite/task_handler/openvino_model_runner.py

@@ -1,4 +1,4 @@
-from typing import List, NamedTuple, Optional, Tuple
+from typing import List, Mapping, NamedTuple, Optional, Tuple
 
 import openvino as ov
 import torch
@@ -12,6 +12,8 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.modeling import SamplingMetadata
 from aphrodite.modeling.model_loader.openvino import get_model
+from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                                  MultiModalInputs)
 
 
 class ModelInput(NamedTuple):
@@ -20,7 +22,7 @@ class ModelInput(NamedTuple):
     attn_metadata: Optional[OpenVINOAttentionMetadata]
     seq_lens: List[int]
     query_lens: List[int]
-    multi_modal_input: Optional[torch.Tensor]
+    multi_modal_kwargs: Mapping[str, BatchedTensors]
 
     @classmethod
     def empty(cls, device):
@@ -29,7 +31,7 @@ class ModelInput(NamedTuple):
                           attn_metadata=None,
                           seq_lens=[],
                           query_lens=[],
-                          multi_modal_input=None)
+                          multi_modal_kwargs={})
 
 
 class OpenVINOModelRunner:
@@ -75,6 +77,10 @@ class OpenVINOModelRunner:
             self.block_size,
         )
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
 
@@ -102,6 +108,7 @@ class OpenVINOModelRunner:
         seq_lens: List[int] = []
         past_lens: List[int] = []
         query_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
         subsequence_begins: List[int] = []
         block_indices: List[int] = []
         block_indices_begins: List[int] = []
@@ -154,6 +161,11 @@ class OpenVINOModelRunner:
                                     and self.sliding_window is None
                                     and is_prompt)
 
+                mm_data = seq_group_metadata.multi_modal_data
+                if mm_data:
+                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                    multi_modal_inputs_list.append(mm_kwargs)
+
                 block_table = seq_group_metadata.block_tables[seq_id]
                 # TODO: Combine chunked prefill and prefix caching by
                 # only allowing multiple of block_size chunk size.
@@ -245,22 +257,24 @@ class OpenVINOModelRunner:
             block_indices_begins=block_indices_begins_tensor,
             max_context_len=max_context_len_tensor,
         )
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
         return ModelInput(
             input_tokens,
             input_positions,
             attn_metadata,
             seq_lens,
             query_lens,
-            None,
+            multi_modal_kwargs=multi_modal_kwargs,
         )
 
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
-               SamplingMetadata, Optional[torch.Tensor], ]:
-        multi_modal_input = None
-
+               SamplingMetadata, Mapping[str, BatchedTensors]]:
         # Prepare input tensors.
         (
             input_tokens,
@@ -268,7 +282,7 @@ class OpenVINOModelRunner:
             attn_metadata,
             seq_lens,
             query_lens,
-            multi_modal_input,
+            multi_modal_kwargs,
         ) = self._prepare_model_input(seq_group_metadata_list)
 
         sampling_metadata = SamplingMetadata.prepare(
@@ -284,7 +298,7 @@ class OpenVINOModelRunner:
             input_positions,
             attn_metadata,
             sampling_metadata,
-            multi_modal_input,
+            multi_modal_kwargs,
         )
 
     @torch.inference_mode()
@@ -298,7 +312,7 @@ class OpenVINOModelRunner:
             input_positions,
             attn_metadata,
             sampling_metadata,
-            multi_modal_input,
+            multi_modal_kwargs,
         ) = self.prepare_input_tensors(seq_group_metadata_list)
 
         model_executable = self.model
@@ -307,9 +321,8 @@ class OpenVINOModelRunner:
             "positions": input_positions,
             "kv_caches": kv_caches,
             "attn_metadata": attn_metadata,
+            **(multi_modal_kwargs or {}),
         }
-        if self.vision_language_config:
-            execute_model_kwargs.update({"image_input": multi_modal_input})
 
         hidden_states = model_executable(**execute_model_kwargs)
 

+ 39 - 5
aphrodite/task_handler/tpu_model_runner.py

@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple
 
 import numpy as np
 import torch
@@ -17,6 +17,8 @@ from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
 from aphrodite.common.utils import make_tensor_with_pad
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                                  MultiModalInputs)
 
 _PAD_SLOT_ID = -1  # NOTE: In PyTorch XLA, index -1 is ignored
 # FIXME: Temporarily disabled top-p sampling since it's too slow.
@@ -65,6 +67,10 @@ class TPUModelRunner:
             False,
         )
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
 
@@ -192,12 +198,14 @@ class TPUModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         prompt_lens: List[int] = []
         slot_mapping: List[List[int]] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -223,6 +231,11 @@ class TPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         assert len(prompt_lens) > 0
         num_prefills = len(prompt_lens)
         num_prefill_tokens = sum(prompt_lens)
@@ -260,17 +273,24 @@ class TPUModelRunner:
             block_tables=None,
             context_lens=None,
         )
-        return input_tokens, input_positions, attn_metadata, prompt_lens
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, attn_metadata, prompt_lens,
+                multi_modal_kwargs)
 
     def _prepare_decode(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         batch_idx = 0
         for seq_group_metadata in seq_group_metadata_list:
@@ -298,6 +318,11 @@ class TPUModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])
 
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
+
         batch_size = _get_padded_batch_size(batch_idx)
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
@@ -331,7 +356,12 @@ class TPUModelRunner:
             block_tables=block_tables,
             context_lens=context_lens,
         )
-        return input_tokens, input_positions, attn_metadata, input_lens
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
+
+        return (input_tokens, input_positions, attn_metadata, input_lens,
+                multi_modal_kwargs)
 
     def _prepare_sample(
         self,
@@ -484,6 +514,7 @@ class ModelWrapper(nn.Module):
         kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
         attn_metadata: AttentionMetadata,
         input_lens: torch.Tensor,
+        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
@@ -497,6 +528,8 @@ class ModelWrapper(nn.Module):
                 memory profiling at initialization.
             attn_metadata: The Pallas attention metadata.
             input_lens: The actual input lengths of shape [batch_size].
+            multi_modal_kwargs: Keyword arguments from multi-modal data to
+                pass to the model.
             t: The sampling temperature of shape [batch_size].
             p: The top-p probability of shape [batch_size].
         """
@@ -541,6 +574,7 @@ class ModelWrapper(nn.Module):
             position_ids,
             kv_caches,
             attn_metadata,
+            **(multi_modal_kwargs or {}),
         )
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

+ 41 - 26
aphrodite/task_handler/xpu_model_runner.py

@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Type, Union)
 
 import torch
 import torch.nn as nn
@@ -11,10 +12,13 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
-                                       SequenceData, SequenceGroupMetadata)
+                                       SequenceGroupMetadata)
 from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad
 from aphrodite.distributed import broadcast_tensor_dict
+from aphrodite.inputs import INPUT_REGISTRY
 from aphrodite.modeling.model_loader import get_model
+from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
+                                  MultiModalInputs)
 from aphrodite.task_handler.model_runner import (AttentionMetadata,
                                                  SamplingMetadata)
 from aphrodite.task_handler.model_runner_base import (
@@ -43,7 +47,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
     input_positions: Optional[torch.Tensor] = None
     attn_metadata: Optional["AttentionMetadata"] = None
     sampling_metadata: Optional["SamplingMetadata"] = None
-    multi_modal_input: Optional[Dict[str, torch.Tensor]] = None
+    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
 
     def as_broadcastable_tensor_dict(
             self) -> Dict[str, Union[int, torch.Tensor]]:
@@ -115,6 +119,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             self.block_size,
         )
 
+        # Multi-modal data support
+        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
+            .create_input_mapper(self.model_config)
+
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
 
@@ -155,12 +163,24 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
         # To exercise the worst scenario for GPU memory consumption,
         # the number of seqs (batch_size) is chosen to maximize the number
         # of images processed.
+        model_config = self.model_config
+        vlm_config = self.vision_language_config
+
+        if vlm_config:
+            max_num_seqs = min(
+                max_num_seqs,
+                int(max_num_batched_tokens / vlm_config.image_feature_size))
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +
                        (group_id < max_num_batched_tokens % max_num_seqs))
 
-            seq_data = SequenceData([0] * seq_len)
-            dummy_multi_modal_data = None
+            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
+                .dummy_data_for_profiling(model_config, seq_len)
+
+            # Having more tokens is over-conservative but otherwise fine
+            assert len(seq_data.prompt_token_ids) >= seq_len, (
+                f"Expected at least {seq_len} dummy tokens for profiling, "
+                f"but got: {len(seq_data.prompt_token_ids)}")
             seq = SequenceGroupMetadata(
                 request_id=str(group_id),
                 is_prompt=True,
@@ -193,7 +213,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             virtual_engine: int = 0,
             finished_requests_ids: Optional[List[str]] = None
     ) -> ModelInputForXPU:
-        multi_modal_input = None
+        multi_modal_kwargs = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
@@ -201,7 +221,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             # Prepare input tensors.
             if is_prompt:
                 (input_tokens, input_positions, attn_metadata, seq_lens,
-                 multi_modal_input
+                 multi_modal_kwargs
                  ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions,
@@ -222,6 +242,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                 "input_positions": input_positions,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
+                "multi_modal_kwargs": multi_modal_kwargs,
             }
             metadata_dict.update(attn_metadata.asdict_zerocopy())
             broadcast_tensor_dict(metadata_dict, src=0)
@@ -231,6 +252,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             input_positions = metadata_dict.pop("input_positions")
             selected_token_indices = metadata_dict.pop(
                 "selected_token_indices")
+            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
             attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
@@ -243,7 +265,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                                 input_positions=input_positions,
                                 attn_metadata=attn_metadata,
                                 sampling_metadata=sampling_metadata,
-                                multi_modal_input=multi_modal_input)
+                                multi_modal_kwargs=multi_modal_kwargs)
 
     def _prepare_decode(
         self,
@@ -349,10 +371,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             "positions": model_input.input_positions,
             "kv_caches": kv_caches,
             "attn_metadata": model_input.attn_metadata,
+            **(model_input.multi_modal_kwargs or {}),
         }
-        if self.vision_language_config:
-            execute_model_kwargs.update(
-                {"image_input": model_input.multi_modal_input})
 
         hidden_states = model_executable(**execute_model_kwargs)
 
@@ -375,13 +395,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
-               Optional[torch.Tensor]]:
+               Mapping[str, BatchedTensors]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
-        multi_modal_input_list: List[torch.Tensor] = []
+        multi_modal_inputs_list: List[MultiModalInputs] = []
 
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
@@ -402,9 +422,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             # is always the first token in the sequence.
             input_positions.extend(list(range(computed_len, seq_len)))
 
-            if seq_group_metadata.multi_modal_data:
-                multi_modal_input_list.append(
-                    seq_group_metadata.multi_modal_data.data)
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                mm_kwargs = self.multi_modal_input_mapper(mm_data)
+                multi_modal_inputs_list.append(mm_kwargs)
 
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
@@ -434,15 +455,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)
 
-        if multi_modal_input_list:
-            assert self.vision_language_config, (
-                "Multi-modal inputs are only supported by "
-                "vision language models.")
-            multi_modal_input = torch.cat(multi_modal_input_list,
-                                          dim=0).to(self.device)
-        else:
-            multi_modal_input = None
-
         num_prompt_tokens = len(input_tokens)
 
         input_tokens = torch.tensor(input_tokens,
@@ -474,5 +486,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             num_decode_tokens=0,
             block_tables=torch.tensor([], device=self.device, dtype=torch.int),
         )
+
+        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
+                                                    device=self.device)
         return (input_tokens, input_positions, attn_metadata, seq_lens,
-                multi_modal_input)
+                multi_modal_kwargs)

+ 9 - 5
aphrodite/transformers_utils/image_processor.py

@@ -1,5 +1,4 @@
-from transformers import AutoImageProcessor
-from transformers.image_processing_utils import BaseImageProcessor
+from typing import cast
 
 
 def get_image_processor(
@@ -7,10 +6,15 @@ def get_image_processor(
     *args,
     trust_remote_code: bool = False,
     **kwargs,
-) -> BaseImageProcessor:
+):
     """Gets an image processor for the given model name via HuggingFace."""
+    # don't put this import at the top level
+    # it will call torch.cuda.device_count()
+    from transformers import AutoImageProcessor
+    from transformers.image_processing_utils import BaseImageProcessor
+
     try:
-        processor: BaseImageProcessor = AutoImageProcessor.from_pretrained(
+        processor = AutoImageProcessor.from_pretrained(
             processor_name,
             *args,
             trust_remote_code=trust_remote_code,
@@ -30,4 +34,4 @@ def get_image_processor(
         else:
             raise e
 
-    return processor
+    return cast(BaseImageProcessor, processor)

BIN
examples/vision/burg.jpg


+ 17 - 28
examples/vision/llava_example.py

@@ -1,9 +1,8 @@
 import os
-import subprocess
 
 from PIL import Image
 
-from aphrodite import LLM
+from aphrodite import LLM, SamplingParams
 
 # The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
 # You can use `.buildkite/download-images.sh` to download them
@@ -17,17 +16,23 @@ def run_llava():
         image_feature_size=576,
     )
 
-    prompt = "<image>" * 576 + (
-        "\nUSER: What is the content of this image?\nASSISTANT:")
-
-    image = Image.open("images/stop_sign.jpg")
-
-    outputs = llm.generate({
-        "prompt": prompt,
-        "multi_modal_data": {
-            "image": image
+    prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
+    image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                              "burg.jpg")
+    image = Image.open(image_path)
+
+    sampling_params = SamplingParams(temperature=1.1,
+                                     min_p=0.06,
+                                     max_tokens=512)
+
+    outputs = llm.generate(
+        {
+            "prompt": prompt,
+            "multi_modal_data": {
+                "image": image
+            }
         },
-    })
+        sampling_params=sampling_params)
 
     for o in outputs:
         generated_text = o.outputs[0].text
@@ -39,20 +44,4 @@ def main():
 
 
 if __name__ == "__main__":
-    # Download from s3
-    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
-    local_directory = "images"
-
-    # Make sure the local directory exists or create it
-    os.makedirs(local_directory, exist_ok=True)
-
-    # Use AWS CLI to sync the directory, assume anonymous access
-    subprocess.check_call([
-        "aws",
-        "s3",
-        "sync",
-        s3_bucket_path,
-        local_directory,
-        "--no-sign-request",
-    ])
     main()

+ 10 - 17
examples/vision/llava_next_example.py

@@ -1,30 +1,25 @@
-from io import BytesIO
-
-import requests
+import os
 from PIL import Image
 
 from aphrodite import LLM, SamplingParams
 
-# Dynamic image input is currently not supported and therefore
-# a fixed image input shape and its corresponding feature size is required.
-
 
 def run_llava_next():
     llm = LLM(
         model="llava-hf/llava-v1.6-mistral-7b-hf",
         image_token_id=32000,
         image_input_shape="1,3,336,336",
-        image_feature_size=1176,
+        # Use the maximum possible value for memory profiling
+        image_feature_size=2928,
     )
 
-    prompt = "[INST] " + "<image>" * 1176 + (
-        "\nWhat is shown in this image? [/INST]")
-    url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
-    image = Image.open(BytesIO(requests.get(url).content))
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     max_tokens=100)
-
+    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
+    image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                              "burg.jpg")
+    image = Image.open(image_path)
+    sampling_params = SamplingParams(temperature=1.1,
+                                     min_p=0.06,
+                                     max_tokens=512)
     outputs = llm.generate(
         {
             "prompt": prompt,
@@ -33,11 +28,9 @@ def run_llava_next():
             }
         },
         sampling_params=sampling_params)
-
     generated_text = ""
     for o in outputs:
         generated_text += o.outputs[0].text
-
     print(f"LLM output:{generated_text}")
 
 

+ 9 - 21
examples/vision/phi3v_example.py

@@ -1,5 +1,4 @@
 import os
-import subprocess
 
 from PIL import Image
 
@@ -18,17 +17,21 @@ def run_phi3v():
         trust_remote_code=True,
         image_token_id=32044,
         image_input_shape="1,3,1008,1344",
-        image_feature_size=1921,
+        # Use the maximum possible value for memory profiling
+        image_feature_size=2653,
         max_num_seqs=5,
     )
 
-    image = Image.open("images/cherry_blossom.jpg")
+    image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                              "burg.jpg")
+    image = Image.open(image_path)
 
     # single-image prompt
-    prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n"  # noqa: E501
-    prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "<s>")
+    prompt = "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n"  # noqa: E501
 
-    sampling_params = SamplingParams(temperature=0, max_tokens=64)
+    sampling_params = SamplingParams(temperature=1.1,
+                                     min_p=0.06,
+                                     max_tokens=512)
 
     outputs = llm.generate(
         {
@@ -44,19 +47,4 @@ def run_phi3v():
 
 
 if __name__ == "__main__":
-    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
-    local_directory = "images"
-
-    # Make sure the local directory exists or create it
-    os.makedirs(local_directory, exist_ok=True)
-
-    # Use AWS CLI to sync the directory, assume anonymous access
-    subprocess.check_call([
-        "aws",
-        "s3",
-        "sync",
-        s3_bucket_path,
-        local_directory,
-        "--no-sign-request",
-    ])
     run_phi3v()