
vlm: add support for Qwen2-VL model (#1015)

AlpinDale, 2 months ago
parent commit 411ac4f405

+ 7 - 4
aphrodite/assets/video.py

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Literal
+from typing import List, Optional
 
 import numpy as np
 import numpy.typing as npt
@@ -68,17 +68,20 @@ def video_to_pil_images_list(
 
 @dataclass(frozen=True)
 class VideoAsset:
-    name: Literal["sample_demo_1.mp4"]
+    name: str = "sample_demo_1.mp4"
     num_frames: int = -1
+    local_path: Optional[str] = None
 
     @property
     def pil_images(self) -> List[Image.Image]:
-        video_path = download_video_asset(self.name)
+        video_path = (self.local_path if self.local_path else
+                      download_video_asset(self.name))
         ret = video_to_pil_images_list(video_path, self.num_frames)
         return ret
 
     @property
     def np_ndarrays(self) -> List[npt.NDArray]:
-        video_path = download_video_asset(self.name)
+        video_path = (self.local_path if self.local_path else
+                      download_video_asset(self.name))
         ret = video_to_ndarrays(video_path, self.num_frames)
         return ret

+ 5 - 2
aphrodite/common/config.py

@@ -1911,8 +1911,11 @@ def _get_and_verify_max_len(
                     "Disabling sliding window is not supported for models "
                     "with rope_scaling. Please raise an issue so we can "
                     "investigate.")
-            assert "factor" in rope_scaling
-            scaling_factor = rope_scaling["factor"]
+            if rope_type == "mrope":
+                scaling_factor = 1
+            else:
+                assert "factor" in rope_scaling
+                scaling_factor = rope_scaling["factor"]
             if rope_type == "yarn":
                 derived_max_model_len = rope_scaling[
                     "original_max_position_embeddings"]

+ 11 - 0
aphrodite/common/sequence.py

@@ -155,6 +155,9 @@ class SequenceData(msgspec.Struct,
     # is called.
     _new_appended_tokens: List[int] = msgspec.field(default_factory=list)
 
+    # It is used to compute mrope_position_ids.
+    _mrope_position_delta: Optional[int] = None
+
     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"
         assert self._output_token_ids.typecode == "l"
@@ -209,6 +212,14 @@ class SequenceData(msgspec.Struct,
         assert isinstance(self._output_token_ids, array)
         return self._output_token_ids
 
+    @property
+    def mrope_position_delta(self) -> Optional[int]:
+        return self._mrope_position_delta
+
+    @mrope_position_delta.setter
+    def mrope_position_delta(self, new_mrope_position_delta):
+        self._mrope_position_delta = new_mrope_position_delta
+
     def append_token_id(self, token_id: int, logprob: float) -> None:
         self._output_token_ids.append(token_id)
         self._new_appended_tokens.append(token_id)
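
As a minimal sketch (values are illustrative), the new property simply stores the per-sequence delta so decode steps can reconstruct mrope positions later:

from array import array
from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       SequenceData)

seq_data = SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
assert seq_data.mrope_position_delta is None

# Typically set from MRotaryEmbedding.get_input_positions(...) during prefill.
seq_data.mrope_position_delta = -8
assert seq_data.mrope_position_delta == -8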

+ 7 - 1
aphrodite/endpoints/chat_utils.py

@@ -101,7 +101,7 @@ class ConversationMessage(TypedDict, total=False):
     """The tool calls generated by the model, such as function calls."""
 
 
-ModalityStr = Literal["image", "audio"]
+ModalityStr = Literal["image", "audio", "video"]
 _T = TypeVar("_T")
 
 
@@ -148,12 +148,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                                               hf_config.image_token_index)
             if model_type in ("chameleon", "internvl_chat"):
                 return "<image>"
+            if model_type == "qwen2_vl":
+                return "<|vision_start|><|image_pad|><|vision_end|>"
 
             raise TypeError(f"Unknown model type: {model_type}")
         elif modality == "audio":
             if model_type == "ultravox":
                 return "<|reserved_special_token_0|>"
             raise TypeError(f"Unknown model type: {model_type}")
+        elif modality == "video":
+            if model_type == "qwen2_vl":
+                return "<|vision_start|><|video_pad|><|vision_end|>"
+            raise TypeError(f"Unknown model type: {model_type}")
         else:
             raise TypeError(f"Unknown modality: {modality}")
 

+ 210 - 84
aphrodite/modeling/layers/rotary_embedding.py

@@ -29,7 +29,6 @@ import torch
 import torch.nn as nn
 
 from aphrodite.modeling._custom_op import CustomOp
-from aphrodite.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -45,26 +44,33 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)
 
 
-# for TPUs
 def _apply_rotary_emb(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
+    is_neox_style: bool,
 ) -> torch.Tensor:
     """
     Args:
         x: [num_tokens, num_heads, head_size]
         cos: [num_tokens, head_size // 2]
         sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
     """
-    orig_dtype = x.dtype
-    x = x.float()
-    x1, x2 = torch.chunk(x, 2, dim=-1)
-    cos = cos.unsqueeze(-2)
-    sin = sin.unsqueeze(-2)
+    cos = cos.unsqueeze(-2).to(x.dtype)
+    sin = sin.unsqueeze(-2).to(x.dtype)
+    if is_neox_style:
+        x1, x2 = torch.chunk(x, 2, dim=-1)
+    else:
+        x1 = x[..., ::2]
+        x2 = x[..., 1::2]
     o1 = x1 * cos - x2 * sin
     o2 = x2 * cos + x1 * sin
-    return torch.cat((o1, o2), dim=-1).to(orig_dtype)
+    if is_neox_style:
+        return torch.cat((o1, o2), dim=-1)
+    else:
+        return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
 class RotaryEmbedding(CustomOp):
@@ -89,16 +95,11 @@ class RotaryEmbedding(CustomOp):
 
         cache = self._compute_cos_sin_cache()
         cache = cache.to(dtype)
+        self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
-        self.use_native2 = current_platform.is_tpu() and is_neox_style
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
-        # NOTE: The HF implementation uses `torch.arange(...).float()`.
-        # However, we use `torch.arange(..., dtype=torch.float)` instead to
-        # avoid numerical issues with large base values (e.g., 10000000).
-        # This may cause a slight numerical difference between the HF
-        # implementation and ours.
         # NOTE: To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
@@ -125,58 +126,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """A PyTorch-native implementation equivalent to forward().
-
-        This method mimics the implementation of the custom CUDA kernel
-        used in `forward_cuda()`.
-        """
-        query = query.view(*query.shape[:-1], -1, self.head_size)
-        key = key.view(*key.shape[:-1], -1, self.head_size)
-
-        query_rot = query[..., :self.rotary_dim]
-        key_rot = key[..., :self.rotary_dim]
-        if self.rotary_dim < self.head_size:
-            query_pass = query[..., self.rotary_dim:]
-            key_pass = key[..., self.rotary_dim:]
-
-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-            positions.device, dtype=query.dtype)
-        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
-                                     if offsets is not None else positions]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        if self.is_neox_style:
-            # NOTE: Here we assume that the positions tensor has the
-            # shape [batch_size, seq_len].
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
-        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
-
-        if self.rotary_dim < self.head_size:
-            query = torch.cat((query_rot, query_pass), dim=-1)
-            key = torch.cat((key_rot, key_pass), dim=-1)
-        else:
-            query = query_rot
-            key = key_rot
-        query = query.flatten(-2)
-        key = key.flatten(-2)
-        return query, key
-
-    def forward_native2(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Another PyTorch-native implementation of forward().
-        This method might perform better than `forward_native()` when compiled.
-        """
+        """A PyTorch-native implementation of forward()."""
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -188,14 +138,14 @@ class RotaryEmbedding(CustomOp):
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
@@ -208,7 +158,7 @@ class RotaryEmbedding(CustomOp):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from aphrodite import _custom_ops as ops
 
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
@@ -245,17 +195,6 @@ class RotaryEmbedding(CustomOp):
                                  self.cos_sin_cache, self.is_neox_style)
         return query, key
 
-    def forward_tpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        forward_fn = (self.forward_native2
-                      if self.use_native2 else self.forward_native)
-        return forward_fn(positions, query, key, offsets)
-
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
@@ -541,6 +480,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
             short_mscale = scaling_factor
         if long_mscale is None:
             long_mscale = scaling_factor
+
         self.short_mscale = short_mscale
         self.long_mscale = long_mscale
 
@@ -738,6 +678,7 @@ class GemmaRotaryEmbedding(RotaryEmbedding):
 
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
+
     def __init__(
         self,
         head_size: int,
@@ -762,6 +703,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         inv_freqs = super()._compute_inv_freq(base)
         low_freq_wavelen = self.orig_max_position / self.low_freq_factor
         high_freq_wavelen = self.orig_max_position / self.high_freq_factor
+
         wave_len = 2 * math.pi / inv_freqs
         if self.low_freq_factor != self.high_freq_factor:
             smooth = (self.orig_max_position / wave_len - self.low_freq_factor
@@ -781,6 +723,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
         return new_freqs
 
 
+class MRotaryEmbedding(RotaryEmbedding):
+    """Rotary Embedding with Multimodal Sections."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        mrope_section: Optional[List[int]] = None,
+    ) -> None:
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+        self.mrope_section = mrope_section
+        if self.mrope_section:
+            assert sum(self.mrope_section) == rotary_dim // 2
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward().
+
+        Args:
+            positions:
+                [num_tokens,] (text only) or
+                [3, num_tokens] (T/H/W positions with multimodal inputs)
+            query: [num_tokens, num_heads * head_size]
+            key: [num_tokens, num_kv_heads * head_size]
+        """
+        assert positions.ndim == 1 or positions.ndim == 2
+
+        num_tokens = positions.shape[-1]
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if positions.ndim == 2:
+            assert self.mrope_section
+
+            cos = torch.cat([
+                m[i]
+                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
+            ],
+                            dim=-1)
+            sin = torch.cat([
+                m[i]
+                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
+            ],
+                            dim=-1)
+
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, self.head_size)
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
+    @staticmethod
+    def get_input_positions(
+        input_tokens: List[int],
+        image_grid_thw: Union[List[List[int]], torch.Tensor],
+        video_grid_thw: Union[List[List[int]], torch.Tensor],
+        image_token_id: int,
+        video_token_id: int,
+        vision_start_token_id: int,
+        vision_end_token_id: int,
+        spatial_merge_size: int,
+        context_len: int = 0,
+    ) -> Tuple[List[List[int]], int]:
+        """Get mrope input positions and delta value."""
+
+        if isinstance(image_grid_thw, torch.Tensor):
+            image_grid_thw = image_grid_thw.tolist()
+        if isinstance(video_grid_thw, torch.Tensor):
+            video_grid_thw = video_grid_thw.tolist()
+
+        input_tokens_tensor = torch.tensor(input_tokens)
+        vision_start_indices = torch.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (vision_tokens == video_token_id).sum()
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            if image_token_id in input_tokens and remain_images > 0:
+                ed_image = input_tokens.index(image_token_id, st)
+            else:
+                ed_image = len(input_tokens) + 1
+            if video_token_id in input_tokens and remain_videos > 0:
+                ed_video = input_tokens.index(video_token_id, st)
+            else:
+                ed_video = len(input_tokens) + 1
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+            llm_grid_t, llm_grid_h, llm_grid_w = \
+                t, h // spatial_merge_size, w // spatial_merge_size
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w).flatten()
+            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                llm_grid_t, -1, llm_grid_w).flatten()
+            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                llm_grid_t, llm_grid_h, -1).flatten()
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:]
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+
+        return llm_positions.tolist(), mrope_position_delta
+
+    @staticmethod
+    def get_next_input_positions(
+        mrope_position_delta: int,
+        context_len: int,
+        seq_len: int,
+    ) -> List[List[int]]:
+        return [
+            list(
+                range(context_len + mrope_position_delta,
+                      seq_len + mrope_position_delta)) for _ in range(3)
+        ]
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
@@ -792,7 +907,7 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
-    rotary_percent: float = 1.0,
+    partial_rotary_factor: float = 1.0,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -805,12 +920,13 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
-    if rotary_percent < 1.0:
-        rotary_dim = int(rotary_dim * rotary_percent)
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
            rope_scaling_args, dtype)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
+
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
@@ -820,7 +936,7 @@ def get_rope(
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
         if scaling_type not in {"su", "longrope"}:
-            scaling_factor = rope_scaling["factor"]
+            scaling_factor = rope_scaling.get("factor", 1.0)
         if scaling_type == "llama3":
             low_freq_factor = rope_scaling["low_freq_factor"]
             high_freq_factor = rope_scaling["high_freq_factor"]
@@ -884,6 +1000,16 @@ def get_rope(
                 head_size, rotary_dim, max_position, original_max_position,
                 base, is_neox_style, dtype, short_factor, long_factor,
                 **extra_kwargs)
+        elif scaling_type == "mrope":
+            return MRotaryEmbedding(
+                head_size,
+                rotary_dim,
+                max_position,
+                base,
+                is_neox_style,
+                dtype,
+                mrope_section=rope_scaling["mrope_section"],
+            )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
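
A minimal usage sketch, assuming a Qwen2-VL-style `rope_scaling` dict (the exact key holding the scaling type may differ between config versions, and the numbers are illustrative):

import torch
from aphrodite.modeling.layers.rotary_embedding import (MRotaryEmbedding,
                                                        get_rope)

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=32768,
    base=1000000,
    is_neox_style=True,
    rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]},
    dtype=torch.bfloat16,
)
assert isinstance(rope, MRotaryEmbedding)

# Decode positions are plain ranges shifted by the per-sequence delta:
next_positions = MRotaryEmbedding.get_next_input_positions(
    mrope_position_delta=-8, context_len=100, seq_len=101)
# -> [[92], [92], [92]]  (one row each for the T/H/W axes)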

+ 5 - 1
aphrodite/modeling/models/__init__.py

@@ -48,9 +48,10 @@ _GENERATION_MODELS = {
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
     "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
-    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
+    "Qwen2VLForConditionalGeneration":
+    ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
     "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
@@ -91,6 +92,9 @@ _MULTIMODAL_MODELS = {
                                           "PaliGemmaForConditionalGeneration"),
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "UltravoxModel": ("ultravox", "UltravoxModel"),
+    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "Qwen2VLForConditionalGeneration": ("qwen2_vl",
+                                        "Qwen2VLForConditionalGeneration"),
 }
 
 _CONDITIONAL_GENERATION_MODELS = {
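
A rough sketch of how such `(module, class)` registry entries are resolved; the helper function here is hypothetical, only the registered tuple comes from the diff:

from importlib import import_module

def _resolve_model_cls(module_suffix: str, class_name: str):
    # Lazily import the model module and fetch the architecture class.
    module = import_module(f"aphrodite.modeling.models.{module_suffix}")
    return getattr(module, class_name)

qwen2_vl_cls = _resolve_model_cls("qwen2_vl", "Qwen2VLForConditionalGeneration")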

+ 1 - 1
aphrodite/modeling/models/granite.py

@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
+from transformers import GraniteConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
@@ -50,7 +51,6 @@ from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.quantization.compressed_tensors.utils import (
     get_compressed_tensors_cache_scale)
-from aphrodite.transformers_utils.configs.granite import GraniteConfig
 
 from .interfaces import SupportsLoRA
 from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers

+ 1129 - 0
aphrodite/modeling/models/qwen2_vl.py

@@ -0,0 +1,1129 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
+from array import array
+from functools import lru_cache, partial
+from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
+                    Union)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from loguru import logger
+from PIL import Image
+from transformers.image_utils import (get_image_size,
+                                      infer_channel_dimension_format,
+                                      to_numpy_array)
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
+    make_batched_images, make_batched_videos, smart_resize)
+
+import aphrodite.common.envs as envs
+from aphrodite.attention import AttentionMetadata
+from aphrodite.attention.selector import (_Backend, backend_name_to_enum,
+                                          get_global_forced_attn_backend)
+from aphrodite.common.config import CacheConfig, MultiModalConfig
+from aphrodite.common.logger import log_once
+from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
+                                       IntermediateTensors, SequenceData)
+from aphrodite.distributed import parallel_state
+from aphrodite.distributed import utils as dist_utils
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from aphrodite.modeling.layers.activation import QuickGELU
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
+from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
+from aphrodite.modeling.models.qwen2 import Qwen2Model
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
+                                  MultiModalInputs)
+from aphrodite.multimodal.base import MultiModalData
+from aphrodite.multimodal.image import cached_get_image_processor
+from aphrodite.platforms import current_platform
+from aphrodite.quantization import QuantizationConfig
+from aphrodite.transformers_utils.configs import (Qwen2VLConfig,
+                                                  Qwen2VLVisionConfig)
+from aphrodite.transformers_utils.processor import get_processor
+
+
+# === Vision Inputs === #
+class Qwen2VLImageInputs(TypedDict):
+    pixel_values: torch.Tensor
+    """Shape: 
+    `(num_patches, num_channels * patch_size * patch_size)`
+    """
+    image_grid_thw: torch.Tensor
+    """Shape: `(num_images, 3)`
+    
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+class Qwen2VLVideoInputs(TypedDict):
+    pixel_values_videos: torch.Tensor
+    """Shape: 
+    `(num_patches, 
+      num_channels * temporal_patch_size * patch_size * patch_size)`
+    """
+    video_grid_thw: torch.Tensor
+    """Shape: `(num_videos, 3)`
+    
+    This should be in `(grid_t, grid_h, grid_w)` format.
+    """
+
+
+# === Vision Encoder === #
+class Qwen2VisionMLP(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int = None,
+        act_layer: Type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            in_features, hidden_features, quant_config=quant_config
+        )
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(
+            hidden_features, in_features, quant_config=quant_config
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    interleaved: bool = False,
+) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos,
+        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)",
+    )
+    sin = repeat(
+        sin,
+        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)",
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos
+            + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
+def apply_rotary_pos_emb_vision(
+    t: torch.Tensor, freqs: torch.Tensor
+) -> torch.Tensor:
+    t_ = t.float()
+    cos = freqs.cos()
+    sin = freqs.sin()
+    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    return output
+
+
+class Qwen2VisionAttention(nn.Module):
+    def __init__(
+        self,
+        embed_dim: Optional[int] = None,
+        num_heads: Optional[int] = None,
+        projection_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        # Per attention head and per partition values.
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+        self.qkv = ColumnParallelLinear(
+            input_size=embed_dim,
+            output_size=3 * projection_size,
+            quant_config=quant_config,
+        )
+        self.proj = RowParallelLinear(
+            input_size=projection_size,
+            output_size=embed_dim,
+            quant_config=quant_config,
+        )
+        # Detect attention implementation.
+        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+        if selected_backend is None:
+            backend_by_env_var: Optional[str] = envs.APHRODITE_ATTENTION_BACKEND
+            if backend_by_env_var is not None:
+                selected_backend = backend_name_to_enum(backend_by_env_var)
+        if selected_backend is None:
+            # For Volta and Turing GPUs, use xformers instead.
+            device_available = current_platform.get_device_capability()[0] >= 8
+            if device_available:
+                from transformers.utils import is_flash_attn_2_available
+
+                if is_flash_attn_2_available():
+                    self._use_flash_attn = True
+                else:
+                    log_once(
+                        level="WARNING",
+                        message=
+                        "The current Qwen2-VL implementation has a bug with "
+                        "`aphrodite-flash-attn` inside the vision module, so "
+                        "the xformers backend is used instead. You can run "
+                        "`pip install flash-attn` to use the flash-attention "
+                        "backend.")
+                    self._use_flash_attn = False
+            else:
+                self._use_flash_attn = False
+        else:
+            if selected_backend == _Backend.FLASH_ATTN:
+                self._use_flash_attn = True
+            elif selected_backend == _Backend.XFORMERS:
+                self._use_flash_attn = False
+            else:
+                raise RuntimeError(
+                    f"Qwen2-VL does not support {selected_backend} backend now."
+                )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        # [s, b, c] --> [s, b, head * 3 * head_dim]
+        x, _ = self.qkv(x)
+        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+        new_x_shape = x.size()[:-1] + (
+            self.num_attention_heads_per_partition,
+            3 * self.hidden_size_per_attention_head,
+        )
+        x = x.view(*new_x_shape)
+        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+        batch_size = q.shape[1]
+        q, k, v = [
+            rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+        ]
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+        if self._use_flash_attn:
+            # from aphrodite_flash_attn.flash_attn_interface import (
+            #   flash_attn_varlen_func)
+            from flash_attn import flash_attn_varlen_func
+
+            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            output = flash_attn_varlen_func(
+                q,
+                k,
+                v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=max_seqlen,
+                max_seqlen_k=max_seqlen,
+                dropout_p=0,
+                causal=False,
+            )
+            context_layer = rearrange(
+                output, "(b s) ... -> b s ...", b=batch_size
+            )
+        else:
+            from xformers import ops as xops
+            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+            attn_bias = BlockDiagonalMask.from_seqlens(
+                q_seqlen=seqlens, kv_seqlen=None
+            )
+            context_layer = xops.memory_efficient_attention_forward(
+                q, k, v, attn_bias=attn_bias, p=0, scale=None
+            )
+        context_layer = rearrange(
+            context_layer, "b s h d -> s b (h d)"
+        ).contiguous()
+        output, _ = self.proj(context_layer)
+        return output
+
+
+class Qwen2VisionBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Type[nn.Module] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.norm1 = norm_layer(dim)
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.attn = Qwen2VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            quant_config=quant_config,
+        )
+        self.mlp = Qwen2VisionMLP(
+            dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+    ) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+        )
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class Qwen2VisionPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_chans: int = 3,
+        embed_dim: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.embed_dim = embed_dim
+        kernel_size = [temporal_patch_size, patch_size, patch_size]
+        self.proj = nn.Conv3d(
+            in_chans,
+            embed_dim,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=False,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        L, C = x.shape
+        x = x.view(
+            L, -1, self.temporal_patch_size, self.patch_size, self.patch_size
+        )
+        x = self.proj(x).view(L, self.embed_dim)
+        return x
+
+
+class Qwen2VisionPatchMerger(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        context_dim: int,
+        norm_layer: Type[nn.Module] = None,
+        spatial_merge_size: int = 2,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.ln_q = norm_layer(context_dim)
+        self.mlp = nn.ModuleList(
+            [
+                ColumnParallelLinear(
+                    self.hidden_size,
+                    self.hidden_size,
+                    bias=True,
+                    quant_config=quant_config,
+                ),
+                nn.GELU(),
+                RowParallelLinear(
+                    self.hidden_size,
+                    d_model,
+                    bias=True,
+                    quant_config=quant_config,
+                ),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.ln_q(x)
+        x = x.view(-1, self.hidden_size)
+        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
+        x_parallel, _ = mlp_fc1(x)
+        x_parallel = mlp_act(x_parallel)
+        out, _ = mlp_fc2(x_parallel)
+        return out
+
+
+class Qwen2VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (
+            theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
+        )
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def update_freqs_cache(self, seqlen: int) -> None:
+        if seqlen > self._seq_len_cached:
+            seqlen *= 2
+            self._seq_len_cached = seqlen
+            self.inv_freq = 1.0 / (
+                self.theta
+                ** (
+                    torch.arange(
+                        0,
+                        self.dim,
+                        2,
+                        dtype=torch.float,
+                        device=self.inv_freq.device,
+                    )
+                    / self.dim
+                )
+            )
+            seq = torch.arange(
+                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+            )
+            freqs = torch.outer(seq, self.inv_freq)
+            self._freqs_cached = freqs
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        self.update_freqs_cache(seqlen)
+        return self._freqs_cached[:seqlen]
+
+
+class Qwen2VisionTransformer(nn.Module):
+    def __init__(
+        self,
+        vision_config: Qwen2VLVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        patch_size: int = vision_config.patch_size
+        temporal_patch_size: int = vision_config.temporal_patch_size
+        spatial_merge_size: int = vision_config.spatial_merge_size
+        in_chans: int = vision_config.in_chans
+        hidden_size: int = vision_config.hidden_size
+        embed_dim: int = vision_config.embed_dim
+        depth: int = vision_config.depth
+        num_heads: int = vision_config.num_heads
+        mlp_ratio: float = vision_config.mlp_ratio
+        self.spatial_merge_size = spatial_merge_size
+        self.patch_embed = Qwen2VisionPatchEmbed(
+            patch_size=patch_size,
+            temporal_patch_size=temporal_patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+        head_dim = embed_dim // num_heads
+        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
+        self.blocks = nn.ModuleList(
+            [
+                Qwen2VisionBlock(
+                    dim=embed_dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    norm_layer=norm_layer,
+                    quant_config=quant_config,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.merger = Qwen2VisionPatchMerger(
+            d_model=hidden_size,
+            context_dim=embed_dim,
+            norm_layer=norm_layer,
+            quant_config=quant_config,
+        )
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.fc2.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.fc2.weight.device
+
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = (
+                hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            wpos_ids = (
+                wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
+            )
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+        # compute position embedding
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        # compute cu_seqlens
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        # transformers
+        x = x.unsqueeze(1)
+        for blk in self.blocks:
+            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+        # adapter
+        x = self.merger(x)
+        return x
+
+
+# === Vision input helpers === #
+cached_get_processor = lru_cache(get_processor)
+
+
+def mm_input_mapper_for_qwen2_vl(
+    ctx: InputContext,
+    data: MultiModalData[object],
+    data_type_key: str,
+) -> MultiModalInputs:
+    """Input mapper for Qwen2-VL."""
+    model_config = ctx.model_config
+    image_processor = cached_get_image_processor(
+        model_config.model, trust_remote_code=model_config.trust_remote_code
+    )
+    if image_processor is None:
+        raise RuntimeError(
+            "No HuggingFace processor is available "
+            "to process the image object"
+        )
+    images = None
+    videos = None
+    if data_type_key == "image":
+        images = data
+    else:
+        assert data_type_key == "video"
+        videos = data
+    try:
+        batch_data = image_processor.preprocess(
+            images=images, videos=videos, return_tensors="pt"
+        ).data
+    except Exception:
+        logger.error("Failed to process image ({})", data)
+        raise
+    return MultiModalInputs(batch_data)
+
+
+image_input_mapper_for_qwen2_vl = partial(
+    mm_input_mapper_for_qwen2_vl, data_type_key="image"
+)
+video_input_mapper_for_qwen2_vl = partial(
+    mm_input_mapper_for_qwen2_vl, data_type_key="video"
+)
+
+
+def _get_vision_info(
+    image_processor,
+    height: int,
+    width: int,
+    min_pixels: int,
+    max_pixels: int,
+    do_resize: bool = True,
+    data_type_key: str = "image",
+    mm_count: int = 1,
+):
+    """Get information (resized height / width and number of vision tokens)
+    of input image / video frame."""
+    if do_resize:
+        resized_height, resized_width = smart_resize(
+            height=height,
+            width=width,
+            factor=image_processor.patch_size * image_processor.merge_size,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+    else:
+        resized_height, resized_width = height, width
+    if data_type_key == "image":
+        grid_t = mm_count
+    else:
+        assert data_type_key == "video"
+        grid_t = max(mm_count // image_processor.temporal_patch_size, 1)
+    grid_h = resized_height // image_processor.patch_size
+    grid_w = resized_width // image_processor.patch_size
+    vision_tokens = grid_t * grid_h * grid_w
+    llm_num_vision_tokens = (
+        vision_tokens
+        // image_processor.merge_size
+        // image_processor.merge_size
+    )
+    return resized_height, resized_width, llm_num_vision_tokens
+
+
+def _get_max_image_info(
+    image_processor,
+    data_type_key: str = "image",
+    mm_count: int = 1,
+):
+    return _get_vision_info(
+        image_processor,
+        height=9999999,
+        width=9999999,
+        # Limit min / max pixels.
+        min_pixels=max(image_processor.min_pixels, 28 * 28),
+        max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
+        data_type_key=data_type_key,
+        mm_count=mm_count,
+    )
+
+
+def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
+    image_processor = cached_get_image_processor(ctx.model_config.model)
+    (
+        max_resized_height,
+        max_resized_width,
+        max_llm_image_tokens,
+    ) = _get_max_image_info(
+        image_processor, data_type_key=data_type_key, mm_count=1
+    )
+    return max_llm_image_tokens
+
+
+get_max_qwen2_vl_image_tokens = partial(
+    get_max_qwen2_vl_mm_tokens, data_type_key="image"
+)
+get_max_qwen2_vl_video_tokens = partial(
+    get_max_qwen2_vl_mm_tokens, data_type_key="video"
+)
+
+
+def dummy_data_for_qwen2_vl(
+    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
+) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
+    image_processor = cached_get_image_processor(ctx.model_config.model)
+    num_images = mm_counts["image"]
+    (
+        max_resized_height,
+        max_resized_width,
+        max_llm_image_tokens,
+    ) = _get_max_image_info(
+        image_processor, data_type_key="image", mm_count=num_images
+    )
+    if seq_len - max_llm_image_tokens - 2 < 0:
+        raise RuntimeError(
+            f"Qwen2-VL cannot process {num_images} images in a prompt, "
+            "please increase max_model_len or reduce image limit by "
+            "--limit-mm-per-prompt."
+        )
+    # Check video counts.
+    num_videos = mm_counts["video"]
+    (
+        max_resized_height,
+        max_resized_width,
+        max_llm_video_tokens,
+    ) = _get_max_image_info(
+        image_processor, data_type_key="video", mm_count=num_videos
+    )
+    if seq_len - max_llm_video_tokens - 2 < 0:
+        raise RuntimeError(
+            f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
+            "please increase max_model_len or reduce video limit by "
+            "--limit-mm-per-prompt."
+        )
+    hf_config = ctx.get_hf_config(Qwen2VLConfig)
+    token_ids = array(
+        APHRODITE_TOKEN_ID_ARRAY_TYPE, [hf_config.vision_start_token_id]
+    )
+    token_ids += (
+        array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [hf_config.image_token_id])
+        * max_llm_image_tokens
+    )
+    token_ids += array(
+        APHRODITE_TOKEN_ID_ARRAY_TYPE, [hf_config.vision_end_token_id]
+    )
+    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0]) * (
+        seq_len - max_llm_image_tokens - 2
+    )
+    dummy_seqdata = SequenceData(token_ids)
+    dummy_image = Image.new(
+        "RGB", (max_resized_width, max_resized_height), color=0
+    )
+    return dummy_seqdata, {
+        "image": dummy_image if num_images == 1 else [dummy_image] * num_images
+    }
+
+
+def _get_llm_num_vision_tokens(
+    mm_inputs: list,
+    data_type_key: str,
+    image_processor,
+):
+    """Get number of vision tokens of multimodal inputs.
+    This method is derived from `transformers.models.qwen2_vl.
+    image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
+    """
+    image = to_numpy_array(mm_inputs[0])
+    input_data_format = infer_channel_dimension_format(image)
+    height, width = get_image_size(image, channel_dim=input_data_format)
+    _, _, llm_num_vision_tokens = _get_vision_info(
+        image_processor,
+        height=height,
+        width=width,
+        min_pixels=image_processor.min_pixels,
+        max_pixels=image_processor.max_pixels,
+        do_resize=image_processor.do_resize,
+        data_type_key=data_type_key,
+        mm_count=len(mm_inputs),
+    )
+    return llm_num_vision_tokens
+
+
+def input_processor_for_qwen2_vl(
+    ctx: InputContext, llm_inputs: LLMInputs
+) -> LLMInputs:
+    multi_modal_data = llm_inputs.get("multi_modal_data", None)
+    if multi_modal_data is None:
+        return llm_inputs
+    image_inputs = multi_modal_data.get("image", None)
+    video_inputs = multi_modal_data.get("video", None)
+    processor = cached_get_processor(ctx.model_config.model)
+    image_processor = processor.image_processor
+    hf_config = ctx.get_hf_config(Qwen2VLConfig)
+    # To avoid redundant processing of vision objects (resize, rescale, etc.),
+    # we extract code of calculating number of vision tokens from
+    # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
+    #
+    # The following code is equivalent to:
+    #    prompt = llm_inputs["prompt"]
+    #    inputs = processor(text=[prompt],
+    #                       images=image_inputs,
+    #                       videos=video_inputs,
+    #                       padding=True,
+    #                       return_tensors="pt")
+    #    prompt_token_ids = inputs["input_ids"][0].tolist()
+    prompt_token_ids = llm_inputs.get("prompt_token_ids", None)
+    if prompt_token_ids is None:
+        prompt = llm_inputs["prompt"]
+        prompt_token_ids = processor.tokenizer(
+            prompt,
+            padding=True,
+            return_tensors=None,
+        )["input_ids"]
+    # Expand image pad tokens.
+    if image_inputs is not None:
+        image_indices = [
+            idx
+            for idx, token in enumerate(prompt_token_ids)
+            if token == hf_config.image_token_id
+        ]
+        image_inputs = make_batched_images(image_inputs)
+        assert len(image_indices) == len(image_inputs)
+        prompt_token_ids_with_image = []
+        for image_cnt, image in enumerate(image_inputs):
+            num_image_tokens = _get_llm_num_vision_tokens(
+                [image],
+                data_type_key="image",
+                image_processor=image_processor,
+            )
+            if image_cnt == 0:
+                non_image_tokens = prompt_token_ids[: image_indices[image_cnt]]
+            else:
+                non_image_tokens = prompt_token_ids[
+                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
+                ]
+            prompt_token_ids_with_image.extend(non_image_tokens)
+            prompt_token_ids_with_image.extend(
+                hf_config.image_token_id for _ in range(num_image_tokens)
+            )
+        prompt_token_ids_with_image.extend(
+            prompt_token_ids[image_indices[-1] + 1 :]
+        )
+        prompt_token_ids = prompt_token_ids_with_image
+    # Expand video pad tokens.
+    if video_inputs is not None:
+        video_indices = [
+            idx
+            for idx, token in enumerate(prompt_token_ids)
+            if token == hf_config.video_token_id
+        ]
+        video_inputs = make_batched_videos(video_inputs)
+        assert len(video_indices) == len(video_inputs)
+        prompt_token_ids_with_video = []
+        for video_cnt, video in enumerate(video_inputs):
+            num_video_tokens = _get_llm_num_vision_tokens(
+                video,
+                data_type_key="video",
+                image_processor=image_processor,
+            )
+            if video_cnt == 0:
+                non_video_tokens = prompt_token_ids[: video_indices[video_cnt]]
+            else:
+                non_video_tokens = prompt_token_ids[
+                    video_indices[video_cnt - 1] + 1 : video_indices[video_cnt]
+                ]
+            prompt_token_ids_with_video.extend(non_video_tokens)
+            prompt_token_ids_with_video.extend(
+                hf_config.video_token_id for _ in range(num_video_tokens)
+            )
+        prompt_token_ids_with_video.extend(
+            prompt_token_ids[video_indices[-1] + 1 :]
+        )
+        prompt_token_ids = prompt_token_ids_with_video
+    return LLMInputs(
+        prompt_token_ids=prompt_token_ids,
+        prompt=llm_inputs["prompt"],
+        multi_modal_data=multi_modal_data,
+    )
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(
+    image_input_mapper_for_qwen2_vl
+)
+@MULTIMODAL_REGISTRY.register_input_mapper(
+    "video", video_input_mapper_for_qwen2_vl
+)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+    "video", get_max_qwen2_vl_video_tokens
+)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
+class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
+    def __init__(
+        self,
+        config: Qwen2VLConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        assert (
+            not cache_config.enable_prefix_caching
+        ), "Qwen2-VL currently does not support prefix caching"
+        self.config = config
+        self.multimodal_config = multimodal_config
+        self.visual = Qwen2VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            # NOTE: Qwen2-VL vision encoder does not support any
+            # quantization method now.
+            quant_config=None,
+        )
+        self.model = Qwen2Model(config, cache_config, quant_config)
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+    def _validate_and_reshape_mm_tensor(
+        self, mm_input: Union[torch.Tensor, List[torch.Tensor]], name: str
+    ) -> torch.Tensor:
+        if not isinstance(mm_input, (torch.Tensor, list)):
+            raise ValueError(
+                f"Incorrect type of {name}. " f"Got type: {type(mm_input)}"
+            )
+        if isinstance(mm_input, torch.Tensor):
+            if mm_input.ndim == 2:
+                return mm_input
+            if mm_input.ndim != 3:
+                raise ValueError(
+                    f"{name} should be 2D or batched 3D tensor. "
+                    f"Got ndim: {mm_input.ndim}"
+                )
+            return torch.concat(list(mm_input))
+        else:
+            return torch.concat(mm_input)
+
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> Optional[Qwen2VLImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_grid_thw = kwargs.pop("image_grid_thw", None)
+        if pixel_values is None:
+            return None
+        pixel_values = self._validate_and_reshape_mm_tensor(
+            pixel_values, "image pixel values"
+        )
+        image_grid_thw = self._validate_and_reshape_mm_tensor(
+            image_grid_thw, "image grid_thw"
+        )
+        if not isinstance(pixel_values, (torch.Tensor, list)):
+            raise ValueError(
+                "Incorrect type of image pixel values. "
+                f"Got type: {type(pixel_values)}"
+            )
+        return Qwen2VLImageInputs(
+            pixel_values=pixel_values, image_grid_thw=image_grid_thw
+        )
+
+    def _parse_and_validate_video_input(
+        self, **kwargs: object
+    ) -> Optional[Qwen2VLVideoInputs]:
+        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
+        video_grid_thw = kwargs.pop("video_grid_thw", None)
+        if pixel_values_videos is None:
+            return None
+        pixel_values_videos = self._validate_and_reshape_mm_tensor(
+            pixel_values_videos, "video pixel values"
+        )
+        video_grid_thw = self._validate_and_reshape_mm_tensor(
+            video_grid_thw, "video grid_thw"
+        )
+        return Qwen2VLVideoInputs(
+            pixel_values_videos=pixel_values_videos,
+            video_grid_thw=video_grid_thw,
+        )
+
+    def _process_image_input(
+        self, image_input: Qwen2VLImageInputs
+    ) -> torch.Tensor:
+        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        image_embeds = self.visual(
+            pixel_values, grid_thw=image_input["image_grid_thw"]
+        )
+        return image_embeds
+
+    def _process_video_input(
+        self, video_input: Qwen2VLVideoInputs
+    ) -> torch.Tensor:
+        pixel_values_videos = video_input["pixel_values_videos"].type(
+            self.visual.dtype
+        )
+        video_embeds = self.visual(
+            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
+        )
+        return video_embeds
+
+    def _merge_multimodal_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        inputs_embeds: torch.Tensor,
+        multimodal_embeddings: torch.Tensor,
+        placeholder_token_id: int,
+    ) -> torch.Tensor:
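+        # Overwrite the embedding rows at placeholder positions with the
+        # vision encoder outputs; the number of placeholder tokens in
+        # input_ids must equal multimodal_embeddings.shape[0].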
+        mask = input_ids == placeholder_token_id
+        inputs_embeds[mask, :] = multimodal_embeddings
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        **kwargs: object,
+    ) -> SamplerOutput:
+        """Run forward pass for Qwen2-VL.
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
+                open-source models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+            pixel_values: Pixel values to be fed to a model.
+                `None` if no images are passed.
+            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
+                `None` if no images are passed.
+            pixel_values_videos: Pixel values of videos to be fed to a model.
+                `None` if no videos are passed.
+            video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
+                `None` if no videos are passed.
+        """
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        video_input = self._parse_and_validate_video_input(**kwargs)
+        if image_input is None and video_input is None:
+            inputs_embeds = None
+        else:
+            if (
+                getattr(self.config, "rope_scaling", {}).get("type", None)
+                == "mrope"
+            ):
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            if image_input is not None:
+                image_embeds = self._process_image_input(image_input)
+                inputs_embeds = self._merge_multimodal_embeddings(
+                    input_ids,
+                    inputs_embeds,
+                    image_embeds,
+                    placeholder_token_id=self.config.image_token_id,
+                )
+            if video_input is not None:
+                video_embeds = self._process_video_input(video_input)
+                inputs_embeds = self._merge_multimodal_embeddings(
+                    input_ids,
+                    inputs_embeds,
+                    video_embeds,
+                    placeholder_token_id=self.config.video_token_id,
+                )
+            input_ids = None
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            inputs_embeds=inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(
+        self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
+    ) -> torch.Tensor:
+        logits = self.logits_processor(
+            self.lm_head, hidden_states, sampling_metadata
+        )
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "visual" in name and "qkv.weight" in name:
+                    visual_num_heads = self.config.vision_config.num_heads
+                    visual_embed_dim = self.config.vision_config.embed_dim
+                    head_size = visual_embed_dim // visual_num_heads
+                    loaded_weight = loaded_weight.view(
+                        3, visual_num_heads, head_size, visual_embed_dim
+                    )
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
+                elif "visual" in name and "qkv.bias" in name:
+                    visual_num_heads = self.config.vision_config.num_heads
+                    visual_embed_dim = self.config.vision_config.embed_dim
+                    head_size = visual_embed_dim // visual_num_heads
+                    loaded_weight = loaded_weight.view(
+                        3, visual_num_heads, head_size
+                    )
+                    loaded_weight = loaded_weight.transpose(0, 1)
+                    loaded_weight = loaded_weight.reshape(-1)
+                try:
+                    param = params_dict[name]
+                except KeyError:
+                    print(params_dict.keys())
+                    raise
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                weight_loader(param, loaded_weight)

+ 3 - 5
aphrodite/multimodal/base.py

@@ -76,14 +76,12 @@ class MultiModalInputs(_MultiModalInputsBase):
         if len(inputs_list) == 0:
             return {}
 
-        keys = inputs_list[0].keys()
-
         item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
 
         for inputs in inputs_list:
-            if inputs.keys() != keys:
-                msg = f"Inputs do not share the same keys ({keys})"
-                raise ValueError(msg)
+            # For models that support multiple modalities (e.g. Qwen2-VL),
+            # different modalities can return different data keys,
+            # so batch() should skip the same-key check.
 
             for k, v in inputs.items():
                 item_lists[k].append(v)

+ 4 - 4
aphrodite/transformers_utils/config.py

@@ -17,12 +17,12 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
 
 import aphrodite.common.envs as envs
 from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
-                                                  EAGLEConfig, GraniteConfig,
+                                                  EAGLEConfig,
                                                   InternVLChatConfig,
                                                   JAISConfig, MedusaConfig,
                                                   MLPSpeculatorConfig,
-                                                  MPTConfig, RWConfig,
-                                                  UltravoxConfig)
+                                                  MPTConfig, Qwen2VLConfig,
+                                                  RWConfig, UltravoxConfig)
 from aphrodite.transformers_utils.utils import check_gguf_file
 
 APHRODITE_USE_MODELSCOPE = envs.APHRODITE_USE_MODELSCOPE
@@ -46,7 +46,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "internvl_chat": InternVLChatConfig,
     "ultravox": UltravoxConfig,
     "eagle": EAGLEConfig,
-    "granite": GraniteConfig,
+    "qwen2_vl": Qwen2VLConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():

+ 4 - 2
aphrodite/transformers_utils/configs/__init__.py

@@ -5,13 +5,14 @@ from aphrodite.transformers_utils.configs.eagle import EAGLEConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from aphrodite.transformers_utils.configs.falcon import RWConfig
-from aphrodite.transformers_utils.configs.granite import GraniteConfig
 from aphrodite.transformers_utils.configs.internvl import InternVLChatConfig
 from aphrodite.transformers_utils.configs.jais import JAISConfig
 from aphrodite.transformers_utils.configs.medusa import MedusaConfig
 from aphrodite.transformers_utils.configs.mlp_speculator import (
     MLPSpeculatorConfig)
 from aphrodite.transformers_utils.configs.mpt import MPTConfig
+from aphrodite.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
+                                                          Qwen2VLVisionConfig)
 from aphrodite.transformers_utils.configs.ultravox import UltravoxConfig
 
 __all__ = [
@@ -25,5 +26,6 @@ __all__ = [
     "MedusaConfig",
     "UltravoxConfig",
     "EAGLEConfig",
-    "GraniteConfig",
+    "Qwen2VLConfig",
+    "Qwen2VLVisionConfig",
 ]

+ 0 - 186
aphrodite/transformers_utils/configs/granite.py

@@ -1,186 +0,0 @@
-# coding=utf-8
-# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Granite model configuration"""
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class GraniteConfig(PretrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of
-    a [`GraniteModel`]. It is used to instantiate an Granite
-    model according to the specified arguments, defining the model architecture.
-    Instantiating a configuration with the defaults will yield a similar
-    configuration to that of the Granite-3B.
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to
-    control the model outputs. Read the documentation from [`PretrainedConfig`]
-    for more information.
-    Args:
-        vocab_size (`int`, *optional*, defaults to 32000):
-            Vocabulary size of the Granite model. Defines the number of
-            different tokens that can be represented by the `inputs_ids`
-            passed when calling [`GraniteModel`]
-        hidden_size (`int`, *optional*, defaults to 4096):
-            Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 11008):
-            Dimension of the MLP representations.
-        num_hidden_layers (`int`, *optional*, defaults to 32):
-            Number of hidden layers in the Transformer decoder.
-        num_attention_heads (`int`, *optional*, defaults to 32):
-            Number of attention heads for each attention layer in the
-            Transformer decoder.
-        num_key_value_heads (`int`, *optional*):
-            This is the number of key_value heads that should be used to
-            implement Grouped Query Attention. If
-            `num_key_value_heads=num_attention_heads`, the model will use Multi
-            Head Attention (MHA), if `num_key_value_heads=1` the model will use
-            Multi Query Attention (MQA) otherwise GQA is used. When converting
-            a multi-head checkpoint to a GQA checkpoint, each group key and
-            value head should be constructed by meanpooling all the original
-            heads within that group. For more details checkout
-            [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not
-            specified, will default to `num_attention_heads`.
-        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
-            The non-linear activation function (function or string) in the
-            decoder.
-        max_position_embeddings (`int`, *optional*, defaults to 2048):
-            The maximum sequence length that this model might ever be used with.
-        initializer_range (`float`, *optional*, defaults to 0.02):
-            The standard deviation of the truncated_normal_initializer for
-            initializing all weight matrices.
-        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
-            The epsilon used by the rms normalization layers.
-        use_cache (`bool`, *optional*, defaults to `True`):
-            Whether or not the model should return the last key/values
-            attentions (not used by all models). Only relevant if
-            `config.is_decoder=True`.
-        pad_token_id (`int`, *optional*):
-            Padding token id.
-        bos_token_id (`int`, *optional*, defaults to 1):
-            Beginning of stream token id.
-        eos_token_id (`int`, *optional*, defaults to 2):
-            End of stream token id.
-        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
-            Whether to tie weight embeddings
-        rope_theta (`float`, *optional*, defaults to 10000.0):
-            The base period of the RoPE embeddings.
-        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE
-            embeddings. Currently supports two scaling strategies: linear and
-            dynamic. Their scaling factor must be a float greater than 1. The
-            expected format is
-            `{"type": strategy name, "factor": scaling factor}`.
-            When using this flag, don't update `max_position_embeddings` to
-            the expected new maximum. See the following thread for more
-            information on how these scaling strategies behave:
-            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/.
-            This is an experimental feature, subject to breaking API changes
-            in future versions.
-        attention_bias (`bool`, *optional*, defaults to `False`):
-            Whether to use a bias in the query, key, value and output
-            projection layers during self-attention.
-        attention_dropout (`float`, *optional*, defaults to 0.0):
-            The dropout ratio for the attention probabilities.
-        mlp_bias (`bool`, *optional*, defaults to `False`):
-            Whether to use a bias in up_proj, down_proj and gate_proj layers
-            in the MLP layers.
-        embedding_multiplier (`float`, *optional*, defaults to 1.0):
-            embedding multiplier
-        logits_scaling (`float`, *optional*, defaults to 1.0):
-            divisor for output logits
-        residual_multiplier (`float`, *optional*, defaults to 1.0):
-            residual multiplier
-        attention_multiplier (`float`, *optional*, defaults to 1.0):
-            attention multiplier
-    ```python
-    >>> from transformers import GraniteModel, GraniteConfig
-    >>> # Initializing a Granite granite-3b style configuration
-    >>> configuration = GraniteConfig()
-    >>> # Initializing a model from the granite-7b style configuration
-    >>> model = GraniteModel(configuration)
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```"""
-
-    model_type = "granite"
-    keys_to_ignore_at_inference = ["past_key_values"]
-
-    def __init__(
-        self,
-        vocab_size=32000,
-        hidden_size=4096,
-        intermediate_size=11008,
-        num_hidden_layers=32,
-        num_attention_heads=32,
-        num_key_value_heads=None,
-        hidden_act="silu",
-        max_position_embeddings=2048,
-        initializer_range=0.02,
-        rms_norm_eps=1e-6,
-        use_cache=True,
-        pad_token_id=None,
-        bos_token_id=1,
-        eos_token_id=2,
-        tie_word_embeddings=False,
-        rope_theta=10000.0,
-        rope_scaling=None,
-        attention_bias=False,
-        attention_dropout=0.0,
-        mlp_bias=False,
-        embedding_multiplier=1.0,
-        logits_scaling=1.0,
-        residual_multiplier=1.0,
-        attention_multiplier=1.0,
-        **kwargs,
-    ):
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.rope_scaling = rope_scaling
-        self.attention_bias = attention_bias
-        self.attention_dropout = attention_dropout
-        self.mlp_bias = mlp_bias
-        self.embedding_multiplier = embedding_multiplier
-        self.logits_scaling = logits_scaling
-        self.residual_multiplier = residual_multiplier
-        self.attention_multiplier = attention_multiplier
-        super().__init__(
-            pad_token_id=pad_token_id,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            tie_word_embeddings=tie_word_embeddings,
-            **kwargs,
-        )
-        rope_config_validation(self)

+ 131 - 0
aphrodite/transformers_utils/configs/qwen2vl.py

@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen2VL model configuration"""
+
+import os
+from typing import Union
+
+from transformers import PretrainedConfig
+
+
+class Qwen2VLVisionConfig(PretrainedConfig):
+    model_type = "qwen2_vl"
+
+    def __init__(
+        self,
+        depth=32,
+        embed_dim=1280,
+        hidden_size=3584,
+        hidden_act="quick_gelu",
+        mlp_ratio=4,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs
+        )
+
+        if config_dict.get("model_type") == "qwen2_vl":
+            config_dict = config_dict["vision_config"]
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+class Qwen2VLConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=152064,
+        hidden_size=8192,
+        intermediate_size=29568,
+        num_hidden_layers=80,
+        num_attention_heads=64,
+        num_key_value_heads=8,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=1000000.0,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=80,
+        attention_dropout=0.0,
+        vision_config=None,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        if isinstance(vision_config, dict):
+            self.vision_config = Qwen2VLVisionConfig(**vision_config)
+        elif vision_config is None:
+            self.vision_config = Qwen2VLVisionConfig()
+
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window
+        self.max_window_layers = max_window_layers
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+        self.rope_scaling = rope_scaling
+
+        # NOTE: the following section from original transformers config
+        # for Qwen2-VL is commented out to address rope config loading issue
+        #
+        # if self.rope_scaling is not None and "type" in self.rope_scaling:
+        #     if self.rope_scaling["type"] == "mrope":
+        #         self.rope_scaling["type"] = "default"
+        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        # rope_config_validation(self)
+
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

+ 35 - 0
aphrodite/transformers_utils/processor.py

@@ -0,0 +1,35 @@
+from typing import cast
+
+
+def get_processor(
+    processor_name: str,
+    *args,
+    trust_remote_code: bool = False,
+    **kwargs,
+):
+    """Gets a processor for the given model name via HuggingFace."""
+    # don't put this import at the top level
+    # it will call torch.cuda.device_count()
+    from transformers import AutoProcessor
+    from transformers.processing_utils import ProcessorMixin
+
+    try:
+        processor = AutoProcessor.from_pretrained(
+            processor_name, *args, trust_remote_code=trust_remote_code, **kwargs
+        )
+    except ValueError as e:
+        # If the error pertains to the processor class not existing or not
+        # currently being imported, suggest using the --trust-remote-code flag.
+        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
+        if not trust_remote_code:
+            err_msg = (
+                "Failed to load the processor. If the processor is "
+                "a custom processor not yet available in the HuggingFace "
+                "transformers library, consider setting "
+                "`trust_remote_code=True` in LLM or using the "
+                "`--trust-remote-code` flag in the CLI."
+            )
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
+    return cast(ProcessorMixin, processor)

+ 94 - 9
aphrodite/worker/model_runner.py

@@ -36,6 +36,7 @@ from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
 from aphrodite.lora.layers import LoRAMapping
 from aphrodite.lora.request import LoRARequest
 from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
+from aphrodite.modeling.layers.rotary_embedding import MRotaryEmbedding
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
@@ -183,6 +184,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         def simple_reinit(self):
             self.input_tokens[0].clear()  # type: ignore
             self.input_positions[0].clear()  # type: ignore
+            self.mrope_input_positions = None  # type: ignore
             self.seq_lens[0] = 0  # type: ignore
             self.orig_seq_lens[0] = 0  # type: ignore
             self.query_lens[0] = 0  # type: ignore
@@ -208,6 +210,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             # Input tokens and positions.
             input_tokens: Optional[List[List[int]]] = None,
             input_positions: Optional[List[List[int]]] = None,
+            mrope_input_positions: Optional[List[List[List[int]]]] = None,
 
             # The sequence length (may be capped to the sliding window).
             seq_lens: Optional[List[int]] = None,
@@ -267,6 +270,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                         for seq_id in range(len(self.seq_ids)):
                             self.input_positions[seq_id].clear()
 
+                    self.mrope_input_positions = None
                     if seq_lens:
                         self.seq_lens = seq_lens
                     else:
@@ -328,6 +332,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             else:
                 self.input_tokens = input_tokens or []
                 self.input_positions = input_positions or []
+                self.mrope_input_positions = mrope_input_positions or None
                 self.seq_lens = seq_lens or []
                 self.orig_seq_lens = orig_seq_lens or []
                 self.query_lens = query_lens or []
@@ -358,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
 
             self.input_tokens = [[] for _ in range(self.n_seqs)]
             self.input_positions = [[] for _ in range(self.n_seqs)]
+            self.mrope_input_positions = None
             self.seq_lens = [0] * self.n_seqs
             self.orig_seq_lens = [0] * self.n_seqs
             self.query_lens = [0] * self.n_seqs
@@ -493,6 +499,16 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         inter_data.query_lens[
             seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
 
+        if seq_data.mrope_position_delta is not None:
+            if inter_data.mrope_input_positions is None:
+                inter_data.mrope_input_positions = [None] * inter_data.n_seqs
+            inter_data.mrope_input_positions[
+                seq_idx] = MRotaryEmbedding.get_next_input_positions(
+                    seq_data.mrope_position_delta,
+                    context_len,
+                    seq_len,
+                )
+
     def _compute_for_prefix_cache_hit(
             self, inter_data: InterDataForSeqGroup, seq_idx: int,
             seq_group_metadata: SequenceGroupMetadata):
@@ -638,6 +654,36 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         mm_kwargs = self.multi_modal_input_mapper(mm_data)
         inter_data.multi_modal_inputs = mm_kwargs
 
+        # special processing for mrope position deltas.
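+        # Qwen2-VL's mrope uses (3, seq_len) position ids (temporal, height,
+        # width); the per-sequence position delta is cached on the sequence
+        # data so that decode steps can continue the positions without
+        # re-running the vision side.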
+        if self.runner.model_is_mrope:
+            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
+            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
+            assert image_grid_thw is not None or video_grid_thw is not None, (
+                "mrope embedding type requires multi-modal input mapper "
+                "returns 'image_grid_thw' or 'video_grid_thw'.")
+            hf_config = self.runner.model_config.hf_config
+            inter_data.mrope_input_positions = [None] * inter_data.n_seqs
+            for seq_idx in range(inter_data.n_seqs):
+                seq_data = seq_group_metadata.seq_data[
+                    inter_data.seq_ids[seq_idx]]
+                token_ids = seq_data.get_token_ids()
+                mrope_input_positions, mrope_position_delta = \
+                    MRotaryEmbedding.get_input_positions(
+                        token_ids,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        image_token_id=hf_config.image_token_id,
+                        video_token_id=hf_config.video_token_id,
+                        vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
+                        spatial_merge_size=hf_config.vision_config.
+                        spatial_merge_size,
+                        context_len=inter_data.context_lens[seq_idx],
+                    )
+                seq_data.mrope_position_delta = mrope_position_delta
+                inter_data.mrope_input_positions[
+                    seq_idx] = mrope_input_positions
+
     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         """Add a sequence group to the builder."""
         seq_ids = seq_group_metadata.seq_data.keys()
@@ -684,10 +730,28 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             # prefix caching and there is no decode request.
             return self.model_input_cls()
 
-        input_positions = []
-        for inter_data in self.inter_data_list:
-            for cur_input_positions in inter_data.input_positions:
-                input_positions.extend(cur_input_positions)
+        mrope_input_positions: Optional[List[List[int]]] = None
+        if any(inter_data.mrope_input_positions is not None
+               for inter_data in self.inter_data_list):
+            mrope_input_positions = [[] for _ in range(3)]
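+            # Emit all positions in the (3, total_tokens) mrope layout;
+            # sequences without mrope data repeat their 1-D positions
+            # across the three rows.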
+            for idx in range(3):
+                for inter_data in self.inter_data_list:
+                    msections = inter_data.mrope_input_positions
+                    if msections is None:
+                        for _seq_input_positions in inter_data.input_positions:
+                            mrope_input_positions[idx].extend(
+                                _seq_input_positions)
+                    else:
+                        for _seq_mrope_input_positions in msections:
+                            mrope_input_positions[idx].extend(
+                                _seq_mrope_input_positions[idx])
+            input_positions = None
+        else:
+            input_positions = []
+            for inter_data in self.inter_data_list:
+                for cur_input_positions in inter_data.input_positions:
+                    input_positions.extend(cur_input_positions)
+
         seq_lens = []
         max_decode_seq_len = 0
         for inter_data in self.inter_data_list:
@@ -722,14 +786,24 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         # Tokens and positions.
         if cuda_graph_pad_size:
             input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
-            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
         assert self.runner.device is not None
         input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                                self.runner.device,
                                                self.runner.pin_memory)
-        input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
-                                                  self.runner.device,
-                                                  self.runner.pin_memory)
+        if mrope_input_positions is not None:
+            for idx in range(3):
+                mrope_input_positions[idx].extend(
+                    itertools.repeat(0, cuda_graph_pad_size))
+            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
+                                                      torch.long,
+                                                      self.runner.device,
+                                                      self.runner.pin_memory)
+        else:
+            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
+            input_positions_tensor = async_tensor_h2d(input_positions,
+                                                      torch.long,
+                                                      self.runner.device,
+                                                      self.runner.pin_memory)
 
         # Sequence and query lengths.
         if cuda_graph_pad_size:
@@ -1249,6 +1323,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             raise RuntimeError("PromptAdapter is not enabled.")
         return self.prompt_adapter_manager.list_adapters()
 
+    @property
+    def model_is_mrope(self) -> bool:
+        """Detect if the model has "mrope" rope_scaling type.
+        mrope requires keep "rope_deltas" between prompt and decoding phases."""
+        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        if rope_scaling is None:
+            return False
+        return rope_scaling.get("type", None) == "mrope"
+
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         """Cuda graph capture a model.
@@ -1283,6 +1366,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_batch_size = self.max_batchsize_to_capture
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
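+        # mrope models consume (3, seq_len) positions, so capture the CUDA
+        # graphs with dummy positions in the same layout.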
+        if self.model_is_mrope:
+            input_positions = torch.tile(input_positions, (3, 1))
         # Prepare dummy previous_hidden_states only if needed by the model.
         # This is used by draft models such as EAGLE.
         previous_hidden_states = None
@@ -1348,7 +1433,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                         "input_ids":
                         input_tokens[:batch_size],
                         "positions":
-                        input_positions[:batch_size],
+                        input_positions[..., :batch_size],
                         "hidden_or_intermediate_states":
                         hidden_or_intermediate_states[
                             virtual_engine]  # type: ignore

BIN
examples/vision/nadeko.mp4


+ 174 - 39
examples/vision/vision_example.py

@@ -6,17 +6,54 @@ on HuggingFace model repository.
 """
 import os
 
+import cv2
+import numpy as np
 from PIL import Image
 from transformers import AutoTokenizer
 
 from aphrodite import LLM, SamplingParams
+from aphrodite.assets.video import VideoAsset
 from aphrodite.common.utils import FlexibleArgumentParser
+from aphrodite.multimodal.utils import sample_frames_from_video
 
 # Input image and question
 image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                           "burg.jpg")
 image = Image.open(image_path).convert("RGB")
-question = "What is the content of this image?"
+img_question = "What is the content of this image?"
+
+# Input video and question
+video_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                          "nadeko.mp4")
+vid_question = "What's in this video?"
+
+
+def load_video_frames(video_path: str, num_frames: int) -> np.ndarray:
+    """
+    Load video frames from a local file path.
+
+    Args:
+        video_path: Path to the video file
+        num_frames: Number of frames to sample from the video
+
+    Returns:
+        np.ndarray: Array of sampled video frames
+    """
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        raise ValueError(f"Could not open video file {video_path}")
+
+    frames = []
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        frames.append(frame)
+    cap.release()
+
+    frames = np.stack(frames)
+    return sample_frames_from_video(frames, num_frames)
+
 
 
 # LLaVA-1.5
@@ -25,17 +62,26 @@ def run_llava(question):
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
 
     llm = LLM(model="llava-hf/llava-1.5-7b-hf")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # LLaVA-1.6/LLaVA-NeXT
 def run_llava_next(question):
 
     prompt = f"[INST] <image>\n{question} [/INST]"
-    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
+    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
 
-    return llm, prompt
+# LLaVA-NeXT-Video
+# Currently only supports video input
+def run_llava_next_video(question):
+    prompt = f"USER: <video>\n{question} ASSISTANT:"
+    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # Fuyu
@@ -43,8 +89,8 @@ def run_fuyu(question):
 
     prompt = f"{question}\n"
     llm = LLM(model="adept/fuyu-8b")
-
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # Phi-3-Vision
@@ -58,20 +104,22 @@ def run_phi3v(question):
     # In this example, we override max_num_seqs to 5 while
     # keeping the original context length of 128k.
     llm = LLM(
-        model="microsoft/Phi-3.5-vision-instruct",
+        model="microsoft/Phi-3-vision-128k-instruct",
         trust_remote_code=True,
         max_num_seqs=5,
     )
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # PaliGemma
 def run_paligemma(question):
 
+    # PaliGemma has a special prompt format for VQA
     prompt = "caption en"
-    llm = LLM(model="google/paligemma2-3b-ft-docci-448")
-
-    return llm, prompt
+    llm = LLM(model="google/paligemma-3b-mix-224")
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # Chameleon
@@ -79,7 +127,8 @@ def run_chameleon(question):
 
     prompt = f"{question}<image>"
     llm = LLM(model="facebook/chameleon-7b")
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # MiniCPM-V
@@ -90,13 +139,26 @@ def run_minicpmv(question):
     # model_name = "HwwwH/MiniCPM-V-2"
 
     # 2.5
-    model_name = "openbmb/MiniCPM-Llama3-V-2_5"
+    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
+
+    # 2.6
+    model_name = "openbmb/MiniCPM-V-2_6"
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
     llm = LLM(
         model=model_name,
         trust_remote_code=True,
     )
+    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
+    # 2.0
+    # stop_token_ids = [tokenizer.eos_id]
+
+    # 2.5
+    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
+
+    # 2.6
+    stop_tokens = ['<|im_end|>', '<|endoftext|>']
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
 
     messages = [{
         'role': 'user',
@@ -105,7 +167,33 @@ def run_minicpmv(question):
     prompt = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)
-    return llm, prompt
+    return llm, prompt, stop_token_ids
+
+
+# InternVL
+def run_internvl(question):
+    model_name = "OpenGVLab/InternVL2-2B"
+
+    llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        max_num_seqs=5,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    # Stop tokens for InternVL
+    # model variants may have different stop tokens
+    # please refer to the model card for the correct "stop words":
+    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
+    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
+    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+    return llm, prompt, stop_token_ids
 
 
 # BLIP-2
@@ -115,39 +203,45 @@ def run_blip2(question):
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompt = f"Question: {question} Answer:"
     llm = LLM(model="Salesforce/blip2-opt-2.7b")
-    return llm, prompt
-
-
-# InternVL
-def run_internvl(question):
-    # Generally, InternVL can use chatml template for conversation
-    TEMPLATE = "<|im_start|>User\n{prompt}<|im_end|>\n<|im_start|>Assistant\n"
-    prompt = f"<image>\n{question}\n"
-    prompt = TEMPLATE.format(prompt=prompt)
-    llm = LLM(
-        model="OpenGVLab/InternVL2-4B",
-        trust_remote_code=True,
-        max_num_seqs=28,
-        tensor_parallel_size=2,
-        max_model_len=8192,
-    )
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 # Qwen
 def run_qwen_vl(question):
+
     llm = LLM(
         model="Qwen/Qwen-VL",
         trust_remote_code=True,
         max_num_seqs=5,
     )
+
     prompt = f"{question}Picture 1: <img></img>\n"
-    return llm, prompt
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
+# Qwen2-VL
+def run_qwen2_vl(question):
+    model_name = "Qwen/Qwen2-VL-7B-Instruct"
+
+    llm = LLM(
+        model=model_name,
+        max_num_seqs=5,
+    )
+
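+    # The <|vision_start|><|image_pad|><|vision_end|> span is a placeholder;
+    # the Qwen2-VL input processor expands <|image_pad|> into roughly one
+    # token per group of merged vision patches during preprocessing.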
+    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+              f"{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
 
 
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,
+    "llava-next-video": run_llava_next_video,
     "fuyu": run_fuyu,
     "phi3_v": run_phi3v,
     "paligemma": run_paligemma,
@@ -156,19 +250,53 @@ model_example_map = {
     "blip-2": run_blip2,
     "internvl_chat": run_internvl,
     "qwen_vl": run_qwen_vl,
+    "qwen2_vl": run_qwen2_vl,
 }
 
 
+def get_multi_modal_input(args):
+    """
+    return {
+        "data": image or video,
+        "question": question,
+    }
+    """
+    if args.modality == "image":
+        return {
+            "data": image,
+            "question": img_question,
+        }
+
+    if args.modality == "video":
+        video = VideoAsset(name="nadeko.mp4",
+                           num_frames=args.num_frames,
+                           local_path=video_path).np_ndarrays
+        return {
+            "data": video,
+            "question": vid_question,
+        }
+
+    msg = f"Modality {args.modality} is not supported."
+    raise ValueError(msg)
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
         raise ValueError(f"Model type {model} is not supported.")
 
-    llm, prompt = model_example_map[model](question)
+    modality = args.modality
+    mm_input = get_multi_modal_input(args)
+    data = mm_input["data"]
+    question = mm_input["question"]
+
+    llm, prompt, stop_token_ids = model_example_map[model](question)
 
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
-    sampling_params = SamplingParams(temperature=0.2, max_tokens=128)
+    sampling_params = SamplingParams(temperature=0.2,
+                                     max_tokens=512,
+                                     stop_token_ids=stop_token_ids)
 
     assert args.num_prompts > 0
     if args.num_prompts == 1:
@@ -176,7 +304,7 @@ def main(args):
         inputs = {
             "prompt": prompt,
             "multi_modal_data": {
-                "image": image
+                modality: data
             },
         }
 
@@ -185,7 +313,7 @@ def main(args):
         inputs = [{
             "prompt": prompt,
             "multi_modal_data": {
-                "image": image
+                modality: data
             },
         } for _ in range(args.num_prompts)]
 
@@ -198,7 +326,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = FlexibleArgumentParser(
         description='Demo on using Aphrodite for offline inference with '
         'vision language models')
     parser.add_argument('--model-type',
                         '-m',
@@ -210,6 +338,13 @@ if __name__ == "__main__":
                         type=int,
                         default=1,
                         help='Number of prompts to run.')
-
+    parser.add_argument('--modality',
+                        type=str,
+                        default="image",
+                        help='Modality of the input.')
+    parser.add_argument('--num-frames',
+                        type=int,
+                        default=16,
+                        help='Number of frames to extract from the video.')
     args = parser.parse_args()
     main(args)

+ 1 - 0
requirements-common.txt

@@ -35,3 +35,4 @@ msgspec
 python-multipart
 partial-json-parser
 opencv-python
+einops