浏览代码

feat: allow image embeddings for VLM input (#686)

AlpinDale 6 月之前
父节点
当前提交
2573b36f6a

+ 48 - 23
aphrodite/modeling/models/blip2.py

@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -29,6 +29,28 @@ _KEYS_TO_MODIFY_MAPPING = {
     "language_model.model": "language_model",
 }
 
+# We use this internally as placeholders since there is no image token
+# defined on the HuggingFace repo
+BLIP2_IMAGE_TOKEN = "<image>"
+BLIP2_IMAGE_TOKEN_ID = 50265
+
+
+class Blip2ImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: (batch_size, num_channels, height, width)"""
+
+
+class Blip2ImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
+
 
 class Blip2QFormerMultiHeadAttention(nn.Module):
 
@@ -376,20 +398,6 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output
 
 
-class Blip2ImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
-
-
-Blip2ImageInputs = Blip2ImagePixelInputs
-
-# We use this internally as placeholders since there is no image token
-# defined on the HuggingFace repo
-BLIP2_IMAGE_TOKEN = "<image>"
-BLIP2_IMAGE_TOKEN_ID = 50265
-
-
 def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
     return hf_config.num_query_tokens
 
@@ -507,18 +515,31 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Blip2ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
 
-        return Blip2ImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+            return Blip2ImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Blip2ImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
 
     def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
@@ -539,6 +560,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
 
     def _process_image_input(self,
                              image_input: Blip2ImageInputs) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
         assert self.vision_model is not None
         image_features = self._process_image_pixels(image_input)
 

+ 7 - 1
aphrodite/modeling/models/clip.py

@@ -87,7 +87,13 @@ def input_processor_for_clip(
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
     if image_feature_size_override is None:
-        image_feature_size = get_clip_image_feature_size(hf_config)
+        image_data = multi_modal_data["image"]
+        if isinstance(image_data, Image.Image):
+            image_feature_size = get_clip_image_feature_size(hf_config)
+        elif isinstance(image_data, torch.Tensor):
+            image_feature_size = image_data.shape[0]
+        else:
+            raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
         image_feature_size = image_feature_size_override
 

+ 10 - 3
aphrodite/modeling/models/fuyu.py

@@ -231,7 +231,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)
 
-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
         image_patches = kwargs.pop("image_patches", None)
 
         if isinstance(image_patches, torch.Tensor):
@@ -246,6 +247,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
                                         data=image_patches)
         return None
 
+    def _process_image_input(
+            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
+
+        assert self.vision_embed_tokens is not None
+        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
+        return vision_embeddings
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -258,8 +266,7 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
         image_input = self._parse_and_validate_image_input(**kwargs)
 
         if image_input is not None:
-            vision_embeddings, _ = self.vision_embed_tokens(
-                image_input["data"])
+            vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.language_model.model.embed_tokens(input_ids)
             inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                     vision_embeddings,

+ 53 - 13
aphrodite/modeling/models/internvl.py

@@ -50,6 +50,18 @@ class InternVLImagePixelInputs(TypedDict):
     """
 
 
+class InternVLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+InternVLImageInputs = Union[InternVLImagePixelInputs,
+                            InternVLImageEmbeddingInputs]
+
+
 # copied from https://huggingface.co/OpenGVLab/InternVL2-1B
 def build_transform(input_size):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
@@ -191,8 +203,10 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
         # add thumbnail image if num_blocks > 1
         if hf_config.use_thumbnail and num_blocks > 1:
             num_blocks += 1
+        image_feature_size = num_blocks * num_patches
+
     elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
+        image_feature_size = image_data.shape[0]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -203,7 +217,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
-    image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
+    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
     new_prompt = prompt.replace('<image>', image_prompt, 1)
     new_prompt_token_ids = tokenizer.encode(new_prompt)
 
@@ -374,23 +388,49 @@ class InternVLChatModel(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
+            self, **kwargs: object) -> Optional[InternVLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_token_id = kwargs.pop("image_token_id", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return InternVLImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
         self.img_context_token_id = image_token_id[0]
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+
+            return InternVLImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_image_input(
+        self,
+        image_input: InternVLImageInputs,
+    ) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
+        assert self.vision_model is not None
+        image_embeds = self.extract_feature(image_input["data"])
 
-        return InternVLImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        return image_embeds
 
     def forward(
         self,
@@ -405,9 +445,9 @@ class InternVLChatModel(nn.Module, SupportsVision):
         if image_input is not None:
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
-            vit_embeds = self.extract_feature(image_input["data"])
+            vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vit_embeds,
+                                                    vision_embeddings,
                                                     self.img_context_token_id)
             input_ids = None
         else:

+ 42 - 18
aphrodite/modeling/models/llava.py

@@ -26,6 +26,23 @@ from .utils import (filter_weights, init_aphrodite_registered_model,
                     merge_vision_embeddings)
 
 
+class LlavaImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: `(batch_size, num_channels, height, width)`"""
+
+
+class LlavaImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
+
+
 # TODO: Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):
 
@@ -48,15 +65,6 @@ class LlavaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-class LlavaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
-
-
-LlavaImageInputs = LlavaImagePixelInputs
-
-
 def get_max_llava_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
@@ -209,18 +217,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        return LlavaImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+            return LlavaImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return LlavaImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -257,6 +277,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
     def _process_image_input(self,
                              image_input: LlavaImageInputs) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
         assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)

+ 42 - 15
aphrodite/modeling/models/llava_next.py

@@ -56,7 +56,16 @@ class LlavaNextImagePixelInputs(TypedDict):
     """
 
 
-LlavaNextImageInputs = LlavaNextImagePixelInputs
+class LlavaNextImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
+                             LlavaNextImageEmbeddingInputs]
 
 
 # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
@@ -204,7 +213,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             input_width=width,
         )
     elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
+        image_feature_size = image_data.shape[0]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -316,26 +325,40 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
+            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
 
-        if not isinstance(image_sizes, torch.Tensor):
-            raise ValueError("Incorrect type of image sizes. "
-                             f"Got type: {type(image_sizes)}")
+            if not isinstance(image_sizes, torch.Tensor):
+                raise ValueError("Incorrect type of image sizes. "
+                                 f"Got type: {type(image_sizes)}")
 
-        return LlavaNextImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-            image_sizes=self._validate_image_sizes(image_sizes),
-        )
+            return LlavaNextImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+                image_sizes=self._validate_image_sizes(image_sizes),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeds. "
+                                 f"Got type: {type(image_embeds)}")
+
+            return LlavaNextImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -462,6 +485,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self,
         image_input: LlavaNextImageInputs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
+
+        if image_input["type"] == "image_embeds":
+            return [image_input["data"]]
+
         patch_embeddings = self._process_image_pixels(image_input)
 
         image_sizes = image_input.get("image_sizes")

+ 47 - 32
aphrodite/modeling/models/paligemma.py

@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 from loguru import logger
@@ -28,6 +28,24 @@ _KEYS_TO_MODIFY_MAPPING = {
 }
 
 
+class PaliGemmaImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: (batch_size, num_channels, height, width)"""
+
+
+class PaliGemmaImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
+                             PaliGemmaImageEmbeddingInputs]
+
+
 def get_max_paligemma_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
     vision_config = hf_config.vision_config
@@ -103,15 +121,6 @@ class PaliGemmaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-class PaliGemmaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
-
-
-PaliGemmaImageInputs = PaliGemmaImagePixelInputs
-
-
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@@ -159,18 +168,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        return PaliGemmaImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+            return PaliGemmaImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return PaliGemmaImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
 
     def _image_pixels_to_features(
         self,
@@ -183,27 +204,21 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
 
         return image_features
 
-    def _process_image_pixels(
+    def _process_image_input(
         self,
-        inputs: PaliGemmaImagePixelInputs,
+        image_input: PaliGemmaImageInputs,
     ) -> torch.Tensor:
-        assert self.vision_tower is not None
 
-        pixel_values = inputs["data"]
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
 
-        return self._image_pixels_to_features(
+        assert self.vision_tower is not None
+        pixel_values = image_input["data"]
+        image_features = self._image_pixels_to_features(
             self.vision_tower,
             pixel_values,
         )
 
-    def _process_image_input(
-        self,
-        image_input: PaliGemmaImageInputs,
-    ) -> torch.Tensor:
-
-        assert self.vision_tower is not None
-        image_features = self._process_image_pixels(image_input, )
-
         return self.multi_modal_projector(image_features)
 
     def forward(self,

+ 70 - 31
aphrodite/modeling/models/phi3v.py

@@ -68,6 +68,33 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                      projection_dim=768)
 
 
+class Phi3VImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
+    """
+
+    image_sizes: torch.Tensor
+    """
+    Shape: `(batch_size, 2)`
+    This should be in `(height, width)` format.
+    """
+
+
+class Phi3VImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
+
+
 class Phi3ImageEmbeddingBase(nn.Module):
 
     def __init__(self) -> None:
@@ -254,23 +281,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline
 
 
-class Phi3VImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
-    """
-
-    image_sizes: torch.Tensor
-    """
-    Shape: `(batch_size, 2)`
-
-    This should be in `(height, width)` format.
-    """
-
 
 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
@@ -386,7 +396,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
                                                           input_width=w,
                                                           input_height=h)
     elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
+        image_feature_size = image_data.shape[0]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -490,25 +500,55 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
+            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is None and image_embeds is None:
+            return None
+
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+
+            if not isinstance(image_sizes, torch.Tensor):
+                raise ValueError("Incorrect type of image sizes. "
+                                 f"Got type: {type(image_sizes)}")
+
+            return Phi3VImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+                image_sizes=self._validate_image_sizes(image_sizes))
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Phi3VImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_image_input(
+        self,
+        image_input: Phi3VImageInputs,
+    ) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
 
-        if not isinstance(image_sizes, torch.Tensor):
-            raise ValueError("Incorrect type of image sizes. "
-                             f"Got type: {type(image_sizes)}")
+        assert self.vision_embed_tokens is not None
+        image_embeds = self.vision_embed_tokens(image_input["data"],
+                                                image_input["image_sizes"])
 
-        return Phi3VImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-            image_sizes=self._validate_image_sizes(image_sizes))
+        return image_embeds
 
     def forward(self,
                 input_ids: torch.Tensor,
@@ -520,8 +560,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         image_input = self._parse_and_validate_image_input(**kwargs)
 
         if image_input is not None:
-            vision_embeddings = self.vision_embed_tokens(
-                image_input["data"], image_input["image_sizes"])
+            vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.model.get_input_embeddings(input_ids)
             inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                     vision_embeddings,

+ 7 - 1
aphrodite/modeling/models/siglip.py

@@ -97,7 +97,13 @@ def input_processor_for_siglip(
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
     if image_feature_size_override is None:
-        image_feature_size = get_siglip_image_feature_size(hf_config)
+        image_data = multi_modal_data["image"]
+        if isinstance(image_data, Image.Image):
+            image_feature_size = get_siglip_image_feature_size(hf_config)
+        elif isinstance(image_data, torch.Tensor):
+            image_feature_size = image_data.shape[0]
+        else:
+            raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
         image_feature_size = image_feature_size_override
 

+ 1 - 1
aphrodite/multimodal/image.py

@@ -126,7 +126,7 @@ class ImagePlugin(MultiModalPlugin):
 
             return MultiModalInputs(batch_data)
         elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
-            raise NotImplementedError("Embeddings input is not supported yet")
+            return MultiModalInputs({"image_embeds": data})
 
         raise TypeError(f"Invalid image type: {type(data)}")