VLM: add support for LLaVA-Onevision model (#1100)

* VLM: add support for LLaVA-Onevision model

* add tests
AlpinDale 1 month ago
parent
commit
a5bfc2bc3d

+ 1 - 1
aphrodite/assets/video.py

@@ -80,7 +80,7 @@ class VideoAsset:
         return ret

     @property
-    def np_ndarrays(self) -> List[npt.NDArray]:
+    def np_ndarrays(self) -> npt.NDArray:
         video_path = (self.local_path if self.local_path else
                       download_video_asset(self.name))
         ret = video_to_ndarrays(video_path, self.num_frames)

+ 2 - 0
aphrodite/modeling/models/__init__.py

@@ -88,6 +88,8 @@ _MULTIMODAL_MODELS = {
                                           "LlavaNextForConditionalGeneration"),
                                           "LlavaNextForConditionalGeneration"),
     "LlavaNextVideoForConditionalGeneration":
     "LlavaNextVideoForConditionalGeneration":
     ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
     ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
+    "LlavaOnevisionForConditionalGeneration":
+    ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
     "MiniCPMV": ("minicpmv", "MiniCPMV"),
     "MiniCPMV": ("minicpmv", "MiniCPMV"),
     "PaliGemmaForConditionalGeneration": ("paligemma",
     "PaliGemmaForConditionalGeneration": ("paligemma",
                                           "PaliGemmaForConditionalGeneration"),
                                           "PaliGemmaForConditionalGeneration"),

+ 19 - 0
aphrodite/modeling/models/clip.py

@@ -2,6 +2,7 @@
 within a vision language model."""
 from typing import Iterable, List, Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.nn as nn
 from PIL import Image
@@ -84,6 +85,24 @@ def dummy_image_for_clip(
     return {"image": image if num_images == 1 else [image] * num_images}
     return {"image": image if num_images == 1 else [image] * num_images}
 
 
 
 
+def dummy_video_for_clip(
+    hf_config: CLIPVisionConfig,
+    num_frames: int,
+    *,
+    image_width_override: Optional[int] = None,
+    image_height_override: Optional[int] = None,
+):
+    pil_frame = dummy_image_for_clip(
+        hf_config,
+        num_images=1,
+        image_width_override=image_width_override,
+        image_height_override=image_height_override)
+    np_frame = np.array(pil_frame["image"])
+    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
+    mm_data = {"video": mm_data_per_video}
+    return mm_data
+
+
 def input_processor_for_clip(
     model_config: ModelConfig,
     hf_config: CLIPVisionConfig,
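
`dummy_video_for_clip` builds the profiling sample for the video modality by converting the dummy CLIP image to an ndarray and stacking `num_frames` copies of it under the `"video"` key; `dummy_video_for_siglip` added later in this commit mirrors it. A rough sketch of the resulting structure, assuming the dummy frame defaults to the config's `image_size`:

import numpy as np
from transformers import CLIPVisionConfig

from aphrodite.modeling.models.clip import dummy_video_for_clip

# Illustrative config values; real ones come from the loaded checkpoint.
cfg = CLIPVisionConfig(image_size=336, patch_size=14)
mm_data = dummy_video_for_clip(cfg, num_frames=16)
video = mm_data["video"]
assert isinstance(video, np.ndarray)
print(video.shape)  # expected (16, 336, 336, 3): frames x height x width x RGB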

+ 869 - 0
aphrodite/modeling/models/llava_onevision.py

@@ -0,0 +1,869 @@
+import math
+from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+                    TypedDict, Union)
+
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+from transformers import (CLIPVisionConfig, LlavaOnevisionConfig,
+                          SiglipVisionConfig)
+from transformers.models.llava_onevision.modeling_llava_onevision import (
+    get_anyres_image_grid_shape, unpad_image)
+from typing_extensions import NotRequired
+
+from aphrodite.attention import AttentionMetadata
+from aphrodite.common.config import CacheConfig, MultiModalConfig
+from aphrodite.common.sequence import IntermediateTensors
+from aphrodite.common.utils import is_list_of
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.sampler import SamplerOutput
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.multimodal import MULTIMODAL_REGISTRY
+from aphrodite.multimodal.utils import (cached_get_tokenizer,
+                                        repeat_and_pad_placeholder_tokens)
+from aphrodite.quantization.base_config import QuantizationConfig
+
+from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
+                   dummy_video_for_clip, get_clip_image_feature_size,
+                   get_clip_patch_grid_length, input_processor_for_clip)
+from .interfaces import SupportsMultiModal
+from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
+                     dummy_video_for_siglip, get_siglip_image_feature_size,
+                     get_siglip_patch_grid_length, input_processor_for_siglip)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_aphrodite_registered_model,
+                    merge_multimodal_embeddings)
+
+# Results in the max possible feature size (2x2 grid of 336x336px tiles)
+MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
+# For profile run
+_MAX_FRAMES_PER_VIDEO = 16
+_MAX_NUM_VIDEOS = 1
+
+
+class LlavaOnevisionVideoPixelInputs(TypedDict):
+    type: Literal["pixel_values_videos"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size, num_frames, num_channels, height, width)`
+    Note that `num_frames` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
+    Note that only one video input per batch is currently supported.
+    """
+
+
+class LlavaOnevisionImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
+    image_sizes: NotRequired[torch.Tensor]
+    """
+    Shape: `(batch_size * num_images, 2)`
+    This should be in `(height, width)` format.
+    """
+
+
+class LlavaOnevisionImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+LlavaOnevisionImageInputs = Union[
+    LlavaOnevisionImagePixelInputs, LlavaOnevisionImageEmbeddingInputs
+]
+LlavaOnevisionMultiInputs = Union[
+    LlavaOnevisionImageInputs, LlavaOnevisionVideoPixelInputs
+]
+
+
+def _get_llava_onevision_image_unpadded_feature_size(
+    height, width, patches, scale_height, scale_width
+):
+    current_height = patches * scale_height
+    current_width = patches * scale_width
+    original_aspect_ratio = width / height
+    current_aspect_ratio = current_width / current_height
+    if original_aspect_ratio > current_aspect_ratio:
+        new_height = int(height * (current_width / width))
+        padding = (current_height - new_height) // 2
+        current_height -= padding * 2
+    else:
+        new_width = int(width * (current_height / height))
+        padding = (current_width - new_width) // 2
+        current_width -= padding * 2
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    ratio = math.sqrt(current_height * current_width / (9 * patches**2))
+    if ratio > 1.1:
+        unpadded_features = int(current_height // ratio) * int(
+            current_width // ratio
+        )
+        newline_features = int(current_height // ratio)
+    return (unpadded_features, newline_features)
+
+
+def get_llava_onevision_image_feature_size(
+    hf_config: LlavaOnevisionConfig,
+    *,
+    input_height: int,
+    input_width: int,
+) -> int:
+    vision_config = hf_config.vision_config
+    if isinstance(vision_config, CLIPVisionConfig):
+        num_patches = get_clip_patch_grid_length(
+            image_size=vision_config.image_size,
+            patch_size=vision_config.patch_size,
+        )
+        base_feature_size = get_clip_image_feature_size(vision_config)
+    elif isinstance(vision_config, SiglipVisionConfig):
+        num_patches = get_siglip_patch_grid_length(
+            image_size=vision_config.image_size,
+            patch_size=vision_config.patch_size,
+        )
+        base_feature_size = get_siglip_image_feature_size(vision_config)
+    else:
+        msg = f"Unsupported vision config: {type(vision_config)}"
+        raise NotImplementedError(msg)
+    strategy = hf_config.vision_feature_select_strategy
+    if strategy == "default":
+        base_feature_size -= 1
+    elif strategy == "full":
+        pass
+    else:
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")
+    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+        image_size=(input_height, input_width),
+        grid_pinpoints=hf_config.image_grid_pinpoints,
+        patch_size=vision_config.image_size,
+    )
+    (
+        unpadded_feature_size,
+        newline_feature_size,
+    ) = _get_llava_onevision_image_unpadded_feature_size(
+        input_height,
+        input_width,
+        num_patches,
+        num_patch_height,
+        num_patch_width,
+    )
+    return unpadded_feature_size + newline_feature_size + base_feature_size
+
+
+def get_max_llava_onevision_image_tokens(ctx: InputContext):
+    return get_llava_onevision_image_feature_size(
+        ctx.get_hf_config(LlavaOnevisionConfig),
+        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+    )
+
+
+def get_llava_onevision_video_frame_feature_size(
+    hf_config: LlavaOnevisionConfig
+) -> int:
+    # Support both CLIPVisionConfig and SiglipVisionConfig
+    image_size = hf_config.vision_config.image_size
+    patch_size = hf_config.vision_config.patch_size
+    spatial_pool_stride = (
+        hf_config.spatial_pool_stride
+        if hasattr(hf_config, "spatial_pool_stride")
+        else 2
+    )
+    height = width = image_size // patch_size
+    return math.ceil(height / spatial_pool_stride) * math.ceil(
+        width / spatial_pool_stride
+    )
+
+
+def get_llava_onevision_video_tokens(ctx: InputContext, num_frames: int) -> int:
+    hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
+    # TODO: support configuring (not supported by HF right now)
+    num_token_image_newline = 1
+    tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
+    video_feature_size = num_frames * tokens_per_frame + num_token_image_newline
+    return video_feature_size
+
+
+def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
+    return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)
+
+
+def dummy_data_for_llava_onevision(
+    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
+):
+    hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
+    vision_config = hf_config.vision_config
+    # TODO: support multiple videos
+    num_videos = mm_counts["video"]
+    if num_videos > _MAX_NUM_VIDEOS:
+        raise NotImplementedError(
+            f"Only {_MAX_NUM_VIDEOS} videos are supported"
+        )
+    # TODO: support configuring the number of frames
+    num_frames = _MAX_FRAMES_PER_VIDEO
+    video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
+    if isinstance(vision_config, CLIPVisionConfig):
+        seq_data = dummy_seq_data_for_clip(
+            vision_config,
+            seq_len,
+            num_videos,
+            image_token_id=hf_config.video_token_index,
+            image_feature_size_override=video_feature_size,
+        )
+        mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames)
+        return seq_data, mm_data
+    elif isinstance(vision_config, SiglipVisionConfig):
+        seq_data = dummy_seq_data_for_siglip(
+            vision_config,
+            seq_len,
+            num_videos,
+            image_token_id=hf_config.video_token_index,
+            image_feature_size_override=video_feature_size,
+        )
+        mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames)
+        return seq_data, mm_data
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+def input_processor_when_multimodal_input_image(
+    ctx: InputContext, llm_inputs: LLMInputs
+):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+    model_config = ctx.model_config
+    hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
+    vision_config = hf_config.vision_config
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        width, height = image_data.size
+        image_feature_size = get_llava_onevision_image_feature_size(
+            hf_config,
+            input_height=height,
+            input_width=width,
+        )
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = [
+            get_llava_onevision_image_feature_size(
+                hf_config, input_height=img.height, input_width=img.width
+            )
+            for img in image_data
+        ]
+    elif isinstance(image_data, torch.Tensor):
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+    vision_config = hf_config.vision_config
+    if isinstance(vision_config, CLIPVisionConfig):
+        return input_processor_for_clip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return input_processor_for_siglip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+def input_processor_when_multimodal_input_video(
+    ctx: InputContext, llm_inputs: LLMInputs
+):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "video" not in multi_modal_data:
+        return llm_inputs
+    video_data = multi_modal_data["video"]
+    model_config = ctx.model_config
+    hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
+    vision_config = hf_config.vision_config
+    if isinstance(video_data, np.ndarray):
+        # Supports both CLIP and Siglip
+        num_frames = video_data.shape[0]
+        video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
+        tokenizer = cached_get_tokenizer(model_config.tokenizer)
+        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+            tokenizer,
+            llm_inputs.get("prompt"),
+            llm_inputs["prompt_token_ids"],
+            placeholder_token_id=hf_config.video_token_index,
+            repeat_count=video_feature_size,
+        )
+        return LLMInputs(
+            prompt_token_ids=new_token_ids,
+            prompt=new_prompt,
+            multi_modal_data=multi_modal_data,
+        )
+    elif is_list_of(video_data, np.ndarray):
+        raise NotImplementedError("Processing multiple videos is not supported")
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+def input_processor_for_llava_onevision(
+    ctx: InputContext, llm_inputs: LLMInputs
+):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or (
+        "video" not in multi_modal_data and "image" not in multi_modal_data
+    ):
+        return llm_inputs
+    if "image" in multi_modal_data:
+        return input_processor_when_multimodal_input_image(ctx, llm_inputs)
+    if "video" in multi_modal_data:
+        return input_processor_when_multimodal_input_video(ctx, llm_inputs)
+    msg = "Unsupported multi data type"
+    raise NotImplementedError(msg)
+
+
+def _init_vision_tower(hf_config: LlavaOnevisionConfig):
+    vision_config = hf_config.vision_config
+    # Initialize the vision tower only up to the required feature layer
+    vision_feature_layer = hf_config.vision_feature_layer
+    if vision_feature_layer < 0:
+        num_hidden_layers = (
+            hf_config.vision_config.num_hidden_layers + vision_feature_layer + 1
+        )
+    else:
+        num_hidden_layers = vision_feature_layer + 1
+    if isinstance(vision_config, CLIPVisionConfig):
+        return CLIPVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
+    elif isinstance(vision_config, SiglipVisionConfig):
+        return SiglipVisionModel(
+            vision_config,
+            num_hidden_layers_override=num_hidden_layers,
+        )
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
+class LlavaOnevisionMultiModalProjector(nn.Module):
+    def __init__(self, config: LlavaOnevisionConfig):
+        super().__init__()
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size,
+            config.text_config.hidden_size,
+            bias=True,
+        )
+        self.act = get_act_fn(config.projector_hidden_act)
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size,
+            config.text_config.hidden_size,
+            bias=True,
+        )
+
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.linear_1(image_features)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_input_mapper("video")
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+    "image", get_max_llava_onevision_image_tokens
+)
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+    "video", get_max_llava_onevision_video_tokens
+)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
+class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
+    def __init__(
+        self,
+        config: LlavaOnevisionConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.multimodal_config = multimodal_config
+        # Initialize the vision tower only up to the required feature layer
+        self.vision_tower = _init_vision_tower(config)
+        self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
+        self.language_model = init_aphrodite_registered_model(
+            config.text_config, cache_config, quant_config
+        )
+        self.image_newline = nn.Parameter(
+            torch.empty(config.text_config.hidden_size)
+        )
+
+    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
+        expected_dims = (2,)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}."
+                )
+
+        for d in data:
+            _validate_shape(d)
+        return data
+
+    def _validate_image_pixel_values(
+        self, data: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
+                raise ValueError(
+                    "The expected shape of pixel values per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}."
+                )
+
+        for d in data:
+            _validate_shape(d)
+        return data
+
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> Optional[LlavaOnevisionImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_sizes = kwargs.pop("image_sizes", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+        if pixel_values is None and image_embeds is None:
+            return None
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError(
+                    "Incorrect type of pixel values. "
+                    f"Got type: {type(pixel_values)}"
+                )
+            if not isinstance(image_sizes, (torch.Tensor, list)):
+                raise ValueError(
+                    "Incorrect type of image sizes. "
+                    f"Got type: {type(image_sizes)}"
+                )
+            return LlavaOnevisionImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_image_pixel_values(
+                    flatten_bn(pixel_values)
+                ),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)
+                ),
+            )
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError(
+                    "Incorrect type of image embeds. "
+                    f"Got type: {type(image_embeds)}"
+                )
+            return LlavaOnevisionImageEmbeddingInputs(
+                type="image_embeds",
+                data=flatten_bn(image_embeds),
+            )
+        raise AssertionError("This line should be unreachable.")
+
+    def _validate_video_pixel_values(
+        self, data: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[2:])
+            if actual_dims != expected_dims:
+                expected_expr = ("num_frames", *map(str, expected_dims))
+                raise ValueError(
+                    "The expected shape of pixel values in each video frame "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}."
+                )
+
+        for d in data:
+            _validate_shape(d)
+        return data
+
+    def _parse_and_validate_video_input(
+        self, **kwargs: object
+    ) -> Optional[LlavaOnevisionVideoPixelInputs]:
+        """
+        A legal video input should have the following dimensions:
+        {
+            "pixel_values_videos" :
+                List[b, Tensor(nb_frames, nb_channels, height, width)]
+        }
+        """
+        pixel_values = kwargs.pop("pixel_values_videos", None)
+        if pixel_values is None:
+            return None
+        if not (
+            is_list_of(pixel_values, (torch.Tensor))  # different shape videos
+            or isinstance(pixel_values, torch.Tensor)
+        ):  # same shape videos
+            raise ValueError(
+                "Incorrect type of pixel values. "
+                f"Got type: {type(pixel_values)}"
+            )
+        return LlavaOnevisionVideoPixelInputs(
+            type="pixel_values_videos",
+            data=pixel_values,
+        )
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = {}
+        if "pixel_values" in kwargs:
+            modalities["images"] = self._parse_and_validate_image_input(
+                **kwargs
+            )
+        if "pixel_values_videos" in kwargs:
+            modalities["videos"] = self._parse_and_validate_video_input(
+                **kwargs
+            )
+        return modalities
+
+    def _select_image_features(
+        self, image_features: torch.Tensor, *, strategy: str
+    ) -> torch.Tensor:
+        if strategy == "default":
+            return image_features[:, 1:]
+        elif strategy == "full":
+            return image_features
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")
+
+    def _image_pixels_to_features(
+        self,
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+        # NOTE: we skip the step to select the vision feature layer since
+        # this is already done inside the vision tower
+        image_features = vision_tower(pixel_values)
+        return self._select_image_features(
+            image_features,
+            strategy=self.config.vision_feature_select_strategy,
+        )
+
+    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
+    def _merge_image_patch_embeddings(
+        self,
+        image_size: torch.Tensor,
+        patch_embeddings: torch.Tensor,
+        *,
+        image_newline=None,
+        vision_aspect_ratio="anyres_max_9",
+        strategy: str,
+    ) -> torch.Tensor:
+        if strategy == "flat":
+            return patch_embeddings.flatten(0, 1)
+        if strategy.startswith("spatial"):
+            height = width = (
+                self.config.vision_config.image_size
+                // self.config.vision_config.patch_size
+            )
+            base_patch_embeds = patch_embeddings[0]
+            if height * width != base_patch_embeds.shape[0]:
+                raise ValueError(
+                    "The number of patches is not consistent with the "
+                    "image size."
+                )
+            if patch_embeddings.shape[0] > 1:
+                other_patch_embeds = patch_embeddings[1:]
+                # Move to CPU to avoid floating-point errors
+                orig_height, orig_width = image_size.tolist()
+                # image_aspect_ratio == "anyres"
+                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+                    (orig_height, orig_width),
+                    self.config.image_grid_pinpoints,
+                    self.config.vision_config.image_size,
+                )
+                num_patches = num_patch_height * num_patch_width
+                # Image patches might be padded for batch processing
+                other_patch_embeds = other_patch_embeds[:num_patches].view(
+                    num_patch_height, num_patch_width, height, width, -1
+                )
+                if "unpad" in strategy:
+                    other_patch_embeds = (
+                        other_patch_embeds.permute(4, 0, 2, 1, 3)
+                        .contiguous()
+                        .flatten(1, 2)
+                        .flatten(2, 3)
+                    )
+                    other_patch_embeds = unpad_image(
+                        other_patch_embeds, (orig_height, orig_width)
+                    )
+                    max_num_patches = int(
+                        vision_aspect_ratio.removeprefix("anyres_max_")
+                    )
+                    channels, curr_height, curr_width = other_patch_embeds.shape
+                    ratio = math.sqrt(
+                        curr_height * curr_width / (max_num_patches * height**2)
+                    )
+                    if ratio > 1.1:
+                        other_patch_embeds = other_patch_embeds[None]
+                        other_patch_embeds = nn.functional.interpolate(
+                            other_patch_embeds,
+                            [
+                                int(curr_height // ratio),
+                                int(curr_width // ratio),
+                            ],
+                            mode="bilinear",
+                        )[0]
+                    if image_newline is not None:
+                        other_patch_embeds = torch.cat(
+                            (
+                                other_patch_embeds,
+                                image_newline[:, None, None]
+                                .expand(*other_patch_embeds.shape[:-1], 1)
+                                .to(other_patch_embeds.device),
+                            ),
+                            dim=-1,
+                        )
+                    other_patch_embeds = other_patch_embeds.flatten(
+                        1, 2
+                    ).transpose(0, 1)
+                else:
+                    other_patch_embeds = (
+                        other_patch_embeds.permute(0, 2, 1, 3, 4)
+                        .contiguous()
+                        .flatten(0, 3)
+                    )
+                merged_patch_embeddings = torch.cat(
+                    (base_patch_embeds, other_patch_embeds), dim=0
+                )
+            else:
+                if "unpad" in strategy:
+                    merged_patch_embeddings = torch.cat(
+                        (
+                            base_patch_embeds,
+                            self.image_newline[None].to(
+                                base_patch_embeds.device
+                            ),
+                        ),
+                        dim=0,
+                    )
+                else:
+                    merged_patch_embeddings = base_patch_embeds
+            return merged_patch_embeddings
+        raise ValueError(f"Unexpected patch merge strategy: {strategy}")
+
+    def _process_image_pixels(
+        self,
+        inputs: LlavaOnevisionImagePixelInputs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        assert self.vision_tower is not None
+        pixel_values = inputs["data"]
+        if isinstance(pixel_values, torch.Tensor):
+            b, num_patches, c, h, w = pixel_values.shape
+            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
+            stacked_image_features = self._image_pixels_to_features(
+                self.vision_tower, stacked_pixel_values
+            )
+            stacked_patch_embeddings = self.multi_modal_projector(
+                stacked_image_features
+            )
+            return stacked_patch_embeddings.view(
+                b, num_patches, *stacked_patch_embeddings.shape[1:]
+            )
+        num_patches_per_batch = [v.shape[0] for v in pixel_values]
+        stacked_pixel_values = torch.cat(pixel_values)
+        stacked_image_features = self._image_pixels_to_features(
+            self.vision_tower, stacked_pixel_values
+        )
+        return [
+            self.multi_modal_projector(image_features)
+            for image_features in torch.split(
+                stacked_image_features, num_patches_per_batch
+            )
+        ]
+
+    def _process_image_input(
+        self,
+        image_input: LlavaOnevisionImageInputs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        if image_input["type"] == "image_embeds":
+            return [image_input["data"]]
+        patch_embeddings = self._process_image_pixels(image_input)
+        image_sizes = image_input.get("image_sizes")
+        if image_sizes is None:
+            batch_size = len(image_input["data"])
+            vision_config = self.config.vision_config
+            default_height = default_width = vision_config.image_size
+            image_sizes = torch.as_tensor(
+                [[default_height, default_width] for _ in range(batch_size)]
+            )
+        return [
+            self._merge_image_patch_embeddings(
+                image_sizes[i],
+                patch_features_batch,
+                image_newline=self.image_newline,
+                strategy="spatial_unpad",
+            )
+            for i, patch_features_batch in enumerate(patch_embeddings)
+        ]
+
+    def _video_pixels_to_features(
+        self,
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+        # NOTE: we skip the step to select the vision feature layer since
+        # this is already done inside the vision tower
+        b, num_videos, frames, c, h, w = pixel_values.shape
+        assert num_videos == _MAX_NUM_VIDEOS
+        pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w)
+        video_features = vision_tower(pixel_values)
+        video_features = self._select_image_features(
+            video_features,
+            strategy=self.config.vision_feature_select_strategy,
+        )
+        video_features = self.multi_modal_projector(video_features)
+        video_features = self.apply_pooling(video_features)
+        video_features = video_features.reshape(
+            b, frames * video_features.shape[1], -1
+        )
+        image_newline = (
+            self.image_newline[None, None, :]
+            .repeat(b, 1, 1)
+            .to(video_features.device)
+        )
+        video_features = torch.cat((video_features, image_newline), dim=1)
+        video_features = video_features.flatten(0, 1)
+        return video_features
+
+    def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
+        assert self.vision_tower is not None
+        video_pixels = inputs["data"]
+        # TODO: support multiple videos per input
+        if isinstance(video_pixels, torch.Tensor):
+            stacked_embeddings = self._video_pixels_to_features(
+                self.vision_tower, video_pixels
+            )
+            return stacked_embeddings
+        else:
+            raise ValueError(
+                f"Unsupported type of video input {type(video_pixels)}"
+            )
+
+    def apply_pooling(self, image_features, stride=2):
+        vision_config = self.config.vision_config
+        height = width = vision_config.image_size // vision_config.patch_size
+        batch_frames, _, dim = image_features.shape
+        image_features = image_features.view(batch_frames, height, width, -1)
+        image_features = image_features.permute(0, 3, 1, 2)
+        # TODO support other pooling types config
+        height, width = image_features.shape[2:]
+        scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
+        image_feature = nn.functional.interpolate(
+            image_features, size=scaled_shape, mode="bilinear"
+        )
+        image_feature = image_feature.permute(0, 2, 3, 1)
+        image_feature = image_feature.view(batch_frames, -1, dim)
+        return image_feature
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        **kwargs: object,
+    ) -> SamplerOutput:
+        """Run forward pass for LlaVA-Onevision.
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            pixel_values_videos: Pixels in each frame for each input video.
+        """
+        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+        # merge video embeddings into input embeddings
+        if modalities:
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids
+            )
+            if "images" in modalities:
+                image_input = modalities["images"]
+                vision_embeddings = self._process_image_input(image_input)
+                inputs_embeds = merge_multimodal_embeddings(
+                    input_ids,
+                    inputs_embeds,
+                    vision_embeddings,
+                    self.config.image_token_index,
+                )
+            if "videos" in modalities:
+                video_input = modalities["videos"]
+                video_embeddings = self._process_video_pixels(video_input)
+                inputs_embeds = merge_multimodal_embeddings(
+                    input_ids,
+                    inputs_embeds,
+                    video_embeddings,
+                    self.config.video_token_index,
+                )
+            input_ids = None
+        else:
+            inputs_embeds = None
+        hidden_states = self.language_model.model(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            None,
+            inputs_embeds=inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.compute_logits(
+            hidden_states, sampling_metadata
+        )
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        return self.language_model.sample(logits, sampling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
+        # load vision encoder
+        self.vision_tower.load_weights(weights_group["vision_tower"])
+        # load mlp projector
+        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(
+                param, "weight_loader", default_weight_loader
+            )
+            weight_loader(param, loaded_weight)
+        # load llm backbone
+        self.language_model.load_weights(weights_group["language_model"])
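
As a quick cross-check of the video token accounting implemented by `get_llava_onevision_video_frame_feature_size` and `get_llava_onevision_video_tokens`, here is a worked example with illustrative vision-tower values (a 384px image size and 14px patch size are assumed purely for the arithmetic; the real numbers come from the checkpoint's `vision_config`):

import math

# Illustrative values, not read from a real checkpoint.
image_size, patch_size = 384, 14
spatial_pool_stride = 2           # default used when the config omits it
num_frames = 16                   # matches _MAX_FRAMES_PER_VIDEO above

side = image_size // patch_size                   # 27 patches per side
pooled = math.ceil(side / spatial_pool_stride)    # 14 after pooling
tokens_per_frame = pooled * pooled                # 196
video_tokens = num_frames * tokens_per_frame + 1  # +1 newline token -> 3137
print(tokens_per_frame, video_tokens)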

+ 19 - 0
aphrodite/modeling/models/siglip.py

@@ -4,6 +4,7 @@ within a vision language model."""
 import math
 from typing import Iterable, List, Optional, Tuple, Union

+import numpy as np
 import torch
 from PIL import Image
 from torch import nn
@@ -89,6 +90,24 @@ def dummy_image_for_siglip(
     return {"image": image if num_images == 1 else [image] * num_images}
     return {"image": image if num_images == 1 else [image] * num_images}
 
 
 
 
+def dummy_video_for_siglip(
+    hf_config: SiglipVisionConfig,
+    num_frames: int,
+    *,
+    image_width_override: Optional[int] = None,
+    image_height_override: Optional[int] = None,
+):
+    pil_frame = dummy_image_for_siglip(
+        hf_config,
+        num_images=1,
+        image_width_override=image_width_override,
+        image_height_override=image_height_override)
+    np_frame = np.array(pil_frame["image"])
+    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
+    mm_data = {"video": mm_data_per_video}
+    return mm_data
+
+
 def input_processor_for_siglip(
     model_config: ModelConfig,
     hf_config: SiglipVisionConfig,

+ 46 - 13
examples/vision/vision_example.py

@@ -57,7 +57,8 @@ def load_video_frames(video_path: str, num_frames: int) -> np.ndarray:
 
 
 
 
 # LLaVA-1.5
-def run_llava(question):
+def run_llava(question, modality):
+    assert modality == "image"

     prompt = f"USER: <image>\n{question}\nASSISTANT:"

@@ -67,7 +68,8 @@ def run_llava(question):
 
 
 
 
 # LLaVA-1.6/LLaVA-NeXT
-def run_llava_next(question):
+def run_llava_next(question, modality):
+    assert modality == "image"

     prompt = f"[INST] <image>\n{question} [/INST]"
     llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
@@ -77,15 +79,34 @@ def run_llava_next(question):
 
 
 # LlaVA-NeXT-Video
 # Currently only support for video input
-def run_llava_next_video(question):
+def run_llava_next_video(question, modality):
+    assert modality == "video"
+
     prompt = f"USER: <video>\n{question} ASSISTANT:"
     llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")
     stop_token_ids = None
     return llm, prompt, stop_token_ids


+# LLaVA-OneVision
+def run_llava_onevision(question, modality):
+    if modality == "video":
+        prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
+        <|im_start|>assistant\n"
+
+    elif modality == "image":
+        prompt = f"<|im_start|>user <image>\n{question}<|im_end|> \
+        <|im_start|>assistant\n"
+
+    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
+              max_model_len=32768)
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 # Fuyu
-def run_fuyu(question):
+def run_fuyu(question, modality):
+    assert modality == "image"

     prompt = f"{question}\n"
     llm = LLM(model="adept/fuyu-8b")
@@ -94,7 +115,8 @@ def run_fuyu(question):
 
 
 
 
 # Phi-3-Vision
-def run_phi3v(question):
+def run_phi3v(question, modality):
+    assert modality == "image"

     prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
     # Note: The default setting of max_num_seqs (256) and
@@ -113,7 +135,8 @@ def run_phi3v(question):
 
 
 
 
 # PaliGemma
-def run_paligemma(question):
+def run_paligemma(question, modality):
+    assert modality == "image"

     # PaliGemma has special prompt format for VQA
     prompt = "caption en"
@@ -123,7 +146,8 @@ def run_paligemma(question):
 
 
 
 
 # Chameleon
-def run_chameleon(question):
+def run_chameleon(question, modality):
+    assert modality == "image"

     prompt = f"{question}<image>"
     llm = LLM(model="facebook/chameleon-7b")
@@ -132,7 +156,8 @@ def run_chameleon(question):
 
 
 
 
 # MiniCPM-V
-def run_minicpmv(question):
+def run_minicpmv(question, modality):
+    assert modality == "image"

     # 2.0
     # The official repo doesn't work yet, so we need to use a fork for now
@@ -172,7 +197,9 @@ def run_minicpmv(question):
 
 
 
 
 # InternVL
-def run_internvl(question):
+def run_internvl(question, modality):
+    assert modality == "image"
+
     model_name = "OpenGVLab/InternVL2-2B"

     llm = LLM(
@@ -198,7 +225,8 @@ def run_internvl(question):
 
 
 
 
 # BLIP-2
-def run_blip2(question):
+def run_blip2(question, modality):
+    assert modality == "image"

     # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
@@ -209,7 +237,8 @@ def run_blip2(question):
 
 
 
 
 # Qwen
-def run_qwen_vl(question):
+def run_qwen_vl(question, modality):
+    assert modality == "image"

     llm = LLM(
         model="Qwen/Qwen-VL",
@@ -223,7 +252,9 @@ def run_qwen_vl(question):
 
 
 
 
 # Qwen2-VL
-def run_qwen2_vl(question):
+def run_qwen2_vl(question, modality):
+    assert modality == "image"
+
     model_name = "Qwen/Qwen2-VL-7B-Instruct"

     llm = LLM(
@@ -258,6 +289,7 @@ model_example_map = {
     "llava": run_llava,
     "llava": run_llava,
     "llava-next": run_llava_next,
     "llava-next": run_llava_next,
     "llava-next-video": run_llava_next_video,
     "llava-next-video": run_llava_next_video,
+    "llava-onevision": run_llava_onevision,
     "fuyu": run_fuyu,
     "fuyu": run_fuyu,
     "phi3_v": run_phi3v,
     "phi3_v": run_phi3v,
     "paligemma": run_paligemma,
     "paligemma": run_paligemma,
@@ -307,7 +339,7 @@ def main(args):
     data = mm_input["data"]
     data = mm_input["data"]
     question = mm_input["question"]
     question = mm_input["question"]
 
 
-    llm, prompt, stop_token_ids = model_example_map[model](question)
+    llm, prompt, stop_token_ids = model_example_map[model](question, modality)
 
 
     # We set temperature to 0.2 so that outputs can be different
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
     # even when all prompts are identical when running batch inference.
@@ -358,6 +390,7 @@ if __name__ == "__main__":
     parser.add_argument('--modality',
                         type=str,
                         default="image",
+                        choices=['image', 'video'],
                         help='Modality of the input.')
     parser.add_argument('--num-frames',
                         type=int,
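
Outside the example's argument plumbing, the new path reduces to the sketch below: build the engine with the OneVision checkpoint, format the prompt the way `run_llava_onevision` does for the video modality, and pass the sampled frames as multi-modal data. This is a hedged sketch that assumes it runs inside this example script (so `load_video_frames` is in scope) and that `llm.generate` accepts a `multi_modal_data` dict the same way the rest of the script uses it; the video path and sampling settings are illustrative.

from aphrodite import LLM, SamplingParams  # assumed top-level exports

# load_video_frames is defined earlier in this example script.
frames = load_video_frames("sample_video.mp4", num_frames=16)  # illustrative path

question = "Why is this video funny?"
prompt = (f"<|im_start|>user <video>\n{question}<|im_end|> "
          f"<|im_start|>assistant\n")

llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf", max_model_len=32768)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"video": frames}},
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)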

+ 0 - 3
tests/models/decoder_only/vision_language/test_llava_next_video.py

@@ -105,9 +105,6 @@ def run_test(
         for asset in video_assets
     ]

-    for video in videos:
-        print(video.shape)
-
     if size_factors is not None:
         inputs_per_video = [(
             [prompt for _ in size_factors],

+ 380 - 0
tests/models/decoder_only/vision_language/test_llava_onevision.py

@@ -0,0 +1,380 @@
+from typing import List, Optional, Tuple, Type, overload
+
+import pytest
+import transformers
+from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
+                          BatchEncoding)
+
+from aphrodite.common.sequence import SampleLogprobs
+from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
+from aphrodite.multimodal.utils import (rescale_image_size, rescale_video_size,
+                                        resize_video, sample_frames_from_video)
+
+from ....conftest import (VIDEO_ASSETS, AphroditeRunner, HfRunner,
+                          PromptImageInput, _VideoAssets)
+from ...utils import check_logprobs_close
+
+# Video test
+HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts(
+    {
+        "sample_demo_1": "<|im_start|>user <video>\nwhy is this video funny? \
+    <|im_end|><|im_start|>assistant\n"
+    }
+)
+models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]
+
+
+def aphrodite_to_hf_output(
+    aphrodite_output: Tuple[List[int], str, Optional[SampleLogprobs]], model: str
+):
+    """Sanitize aphrodite output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = aphrodite_output
+    config = AutoConfig.from_pretrained(model)
+    video_token_id = config.video_token_index
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+    hf_output_ids = [
+        token_id
+        for idx, token_id in enumerate(output_ids)
+        if token_id != video_token_id or output_ids[idx - 1] != video_token_id
+    ]
+    hf_output_str = output_str
+    if hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
+@overload
+def run_video_test(
+    hf_runner: Type[HfRunner],
+    aphrodite_runner: Type[AphroditeRunner],
+    video_assets: _VideoAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    num_frames: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+@overload
+def run_video_test(
+    hf_runner: Type[HfRunner],
+    aphrodite_runner: Type[AphroditeRunner],
+    video_assets: _VideoAssets,
+    model: str,
+    *,
+    sizes: List[Tuple[int, int]],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    num_frames: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+def run_video_test(
+    hf_runner: Type[HfRunner],
+    aphrodite_runner: Type[AphroditeRunner],
+    video_assets: _VideoAssets,
+    model: str,
+    *,
+    size_factors: Optional[List[float]] = None,
+    sizes: Optional[List[Tuple[int, int]]] = None,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    num_frames: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    videos = [
+        sample_frames_from_video(asset.np_ndarrays, num_frames)
+        for asset in video_assets
+    ]
+    if size_factors is not None:
+        inputs_per_video = [
+            (
+                [prompt for _ in size_factors],
+                [rescale_video_size(video, factor) for factor in size_factors],
+            )
+            for video, prompt in zip(videos, HF_VIDEO_PROMPTS)
+        ]
+    elif sizes is not None:
+        inputs_per_video = [
+            (
+                [prompt for _ in sizes],
+                [resize_video(video, size) for size in sizes],
+            )
+            for video, prompt in zip(videos, HF_VIDEO_PROMPTS)
+        ]
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")
+    # max_model_len should be greater than the video feature size
+    with aphrodite_runner(
+        model,
+        dtype=dtype,
+        max_model_len=4096,
+        tensor_parallel_size=tensor_parallel_size,
+        distributed_executor_backend=distributed_executor_backend,
+        enforce_eager=True,
+    ) as aphrodite_model:
+        aphrodite_outputs_per_video = [
+            aphrodite_model.generate_greedy_logprobs(
+                prompts, max_tokens, num_logprobs=num_logprobs, videos=videos
+            )
+            for prompts, videos in inputs_per_video
+        ]
+
+    def process(hf_inputs: BatchEncoding):
+        hf_inputs["pixel_values_videos"] = hf_inputs["pixel_values_videos"].to(
+            torch_dtype
+        )  # type: ignore
+        return hf_inputs
+
+    with hf_runner(
+        model,
+        dtype=dtype,
+        postprocess_inputs=process,
+        auto_cls=AutoModelForVision2Seq,
+    ) as hf_model:
+        hf_outputs_per_video = [
+            hf_model.generate_greedy_logprobs_limit(
+                prompts, max_tokens, num_logprobs=num_logprobs, videos=videos
+            )
+            for prompts, videos in inputs_per_video
+        ]
+    for hf_outputs, aphrodite_outputs in zip(
+        hf_outputs_per_video, aphrodite_outputs_per_video
+    ):
+        # TODO: Check whether using original CLIPVisionModel can improve
+        # consistency against HF
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                aphrodite_to_hf_output(aphrodite_output, model)
+                for aphrodite_output in aphrodite_outputs
+            ],
+            name_0="hf",
+            name_1="aphrodite",
+        )
+
+
+@pytest.mark.skipif(
+    transformers.__version__ < "4.45",
+    reason="Waiting for next transformers release",
+)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No video
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("num_frames", [16])
+def test_models(
+    hf_runner,
+    aphrodite_runner,
+    video_assets,
+    model,
+    size_factors,
+    dtype,
+    max_tokens,
+    num_logprobs,
+    num_frames,
+) -> None:
+    """Inference result should be the same between hf and aphrodite.
+    All the video fixtures for the test are under tests/videos.
+    For the huggingface runner, we provide the np.ndarray as input.
+    For the aphrodite runner, we provide MultiModalDataDict objects
+    and the corresponding MultiModalConfig as input.
+    Note that the text input is also adjusted to abide by the aphrodite
+    contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    run_video_test(
+        hf_runner,
+        aphrodite_runner,
+        video_assets,
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        num_frames=num_frames,
+        tensor_parallel_size=1,
+    )
+
+
+@pytest.mark.skipif(
+    transformers.__version__ < "4.45",
+    reason="Waiting for next transformers release",
+)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "sizes",
+    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("num_frames", [16])
+def test_models_fixed_sizes(
+    hf_runner,
+    aphrodite_runner,
+    video_assets,
+    model,
+    sizes,
+    dtype,
+    max_tokens,
+    num_logprobs,
+    num_frames,
+) -> None:
+    run_video_test(
+        hf_runner,
+        aphrodite_runner,
+        video_assets,
+        model,
+        sizes=sizes,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        num_frames=num_frames,
+        tensor_parallel_size=1,
+    )
+
+
+# Image test
+_LIMIT_IMAGE_PER_PROMPT = 4
+
+
+def run_image_test(
+    hf_runner: Type[HfRunner],
+    aphrodite_runner: Type[AphroditeRunner],
+    inputs: List[Tuple[List[str], PromptImageInput]],
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    # max_model_len should be greater than image_feature_size
+    with aphrodite_runner(
+        model,
+        dtype=dtype,
+        max_model_len=32768,
+        tensor_parallel_size=tensor_parallel_size,
+        distributed_executor_backend=distributed_executor_backend,
+        enforce_eager=True,
+        limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT},
+    ) as aphrodite_model:
+        aphrodite_outputs_per_image = [
+            aphrodite_model.generate_greedy_logprobs(
+                prompts, max_tokens, num_logprobs=num_logprobs, images=images
+            )
+            for prompts, images in inputs
+        ]
+
+    def process(hf_inputs: BatchEncoding):
+        hf_inputs["pixel_values"] = hf_inputs["pixel_values"].to(torch_dtype)  # type: ignore
+        return hf_inputs
+
+    with hf_runner(
+        model,
+        dtype=dtype,
+        postprocess_inputs=process,
+        auto_cls=AutoModelForVision2Seq,
+    ) as hf_model:
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(
+                prompts, max_tokens, num_logprobs=num_logprobs, images=images
+            )
+            for prompts, images in inputs
+        ]
+    for hf_outputs, aphrodite_outputs in zip(
+        hf_outputs_per_image, aphrodite_outputs_per_image
+    ):
+        # TODO: Check whether using original CLIPVisionModel can improve
+        # consistency against HF
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                aphrodite_to_hf_output(aphrodite_output, model)
+                for aphrodite_output in aphrodite_outputs
+            ],
+            name_0="hf",
+            name_1="aphrodite",
+        )
+
+
+@pytest.mark.skipif(
+    transformers.__version__ < "4.45",
+    reason="Waiting for next transformers release",
+)
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_multiple_image_inputs(
+    hf_runner, aphrodite_runner, image_assets, model, dtype, max_tokens, num_logprobs
+) -> None:
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+    inputs = [
+        (
+            [
+                "<|im_start|>user <image><image>\nDescribe 2 images. \
+                <|im_end|><|im_start|>assistant\n",
+                "<|im_start|>user <image><image>\nDescribe 2 images. \
+                <|im_end|><|im_start|>assistant\n",
+                "<|im_start|>user <image><image><image><image>\nDescribe 4 images. \
+                <|im_end|><|im_start|>assistant\n",
+                "<|im_start|>user <image>\nWhat is the season? \
+                <|im_end|><|im_start|>assistant\n",
+            ],
+            [
+                [stop_sign, cherry_blossom],
+                # Images with different sizes and aspect-ratios
+                [
+                    rescale_image_size(stop_sign, 0.1),
+                    stop_sign,
+                ],
+                [
+                    stop_sign,
+                    rescale_image_size(stop_sign, 0.25),
+                    cherry_blossom.resize((183, 488)),
+                    cherry_blossom.resize((488, 183)),
+                ],
+                cherry_blossom,
+            ],
+        )
+    ]
+    run_image_test(
+        hf_runner,
+        aphrodite_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )

+ 2 - 1
tests/models/test_registry.py

@@ -6,7 +6,8 @@ from aphrodite.modeling.models import _MODELS, ModelRegistry
 
 
 @pytest.mark.parametrize("model_cls", _MODELS)
 @pytest.mark.parametrize("model_cls", _MODELS)
 def test_registry_imports(model_cls):
 def test_registry_imports(model_cls):
-    if (model_cls == "Qwen2VLForConditionalGeneration"
+    if (model_cls in ("LlavaOnevisionForConditionalGeneration",
+                      "Qwen2VLForConditionalGeneration")
             and transformers.__version__ < "4.45"):
         pytest.skip("Waiting for next transformers release")