瀏覽代碼

vlm: add support for molmo vision model (#1069)

AlpinDale 2 月之前
父節點
當前提交
acc0c727c8

+ 2 - 0
aphrodite/endpoints/chat_utils.py

@@ -161,6 +161,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                 return "<image>"
             if model_type == "qwen2_vl":
                 return "<|vision_start|><|image_pad|><|vision_end|>"
+            if model_type == "molmo":
+                return ""
 
             raise TypeError(f"Unknown model type: {model_type}")
         elif modality == "audio":

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -98,6 +98,7 @@ _MULTIMODAL_MODELS = {
                                         "Qwen2VLForConditionalGeneration"),
     "PixtralForConditionalGeneration": ("pixtral",
                                         "PixtralForConditionalGeneration"),
+    "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
 }
 
 _CONDITIONAL_GENERATION_MODELS = {

+ 1294 - 0
aphrodite/modeling/models/molmo.py

@@ -0,0 +1,1294 @@
+import logging
+import math
+import re
+from array import array
+from dataclasses import dataclass
+from functools import lru_cache, partial
+from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict,
+                    Union)
+
+import torch
+from einops import rearrange
+from PIL import Image
+from torch import nn
+from torch.nn import functional as F
+from transformers import PretrainedConfig
+
+import aphrodite.common.envs as envs
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.attention.selector import (_Backend, backend_name_to_enum,
+                                          get_global_forced_attn_backend)
+from aphrodite.common.config import CacheConfig, MultiModalConfig
+from aphrodite.common.logger import log_once
+from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
+                                       IntermediateTensors, SequenceData)
+from aphrodite.distributed import (get_pp_group,
+                                   get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   split_tensor_along_last_dim,
+                                   tensor_model_parallel_all_gather)
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from aphrodite.modeling.layers.activation import QuickGELU, SiluAndMul
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
+from aphrodite.modeling.models.utils import make_layers
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
+from aphrodite.platforms import current_platform
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.transformers_utils.processor import get_processor
+
+log = logging.getLogger(__name__)
+
+# TODO: hard-coded for now. Consider making it configurable.
+VIT_LAYERS = [-2, -9]
+NUM_PREFIX_TOKENS = 1
+ADDITIONAL_VOCAB_SIZE = 128
+
+
+class MolmoImageInputs(TypedDict):
+    images: torch.Tensor
+    """Shape:
+    `(batch_size, num_crops, num_patch, patch_dim)`
+    """
+
+    image_input_idx: torch.Tensor
+    """Shape:
+    `(batch_size, num_crops, num_patch)`
+    """
+
+    seq_len: torch.Tensor
+    """Shape:
+    `(batch_size, )`
+    """
+
+    image_masks: Optional[torch.Tensor]
+    """Shape:
+    `(batch_size, num_crops, num_patch)`
+    """
+
+
+@dataclass
+class VisionBackboneConfig:
+    image_default_input_size: Tuple[int, int] = (336, 336)
+    image_patch_size: int = 14
+    image_pos_patch_size: int = 14
+    image_emb_dim: int = 1024
+    image_num_heads: int = 16
+    image_num_key_value_heads: int = 16
+    image_num_layers: int = 23
+    image_mlp_dim: int = 4096
+    image_mlp_activations: str = "quick_gelu"
+    image_num_pos: int = 577
+    image_norm_eps: float = 1e-5
+
+    def __post_init__(self):
+        self.image_default_input_size = tuple(
+            self.image_default_input_size)  # type: ignore[assignment]
+
+    @property
+    def image_num_patch(self):
+        h, w = self.image_default_input_size
+        return h // self.image_patch_size, w // self.image_patch_size
+
+
+class ViTMLP(nn.Module):
+    """MLP used in Vision Transformer."""
+
+    def __init__(
+        self,
+        config: VisionBackboneConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.w1 = ColumnParallelLinear(
+            config.image_emb_dim,
+            config.image_mlp_dim,
+            bias=True,
+            quant_config=quant_config,
+        )
+        # Activation function.
+        assert config.image_mlp_activations == "quick_gelu"
+        self.act = QuickGELU()
+        self.w2 = RowParallelLinear(
+            config.image_mlp_dim,
+            config.image_emb_dim,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.w1(x)
+        x = self.act(x)
+        x, _ = self.w2(x)
+        return x
+
+
+class MultiHeadDotProductAttention(nn.Module):
+    """Multi-head attention used in Vision Transformer."""
+
+    def __init__(
+        self,
+        config: VisionBackboneConfig,
+        use_bias: bool = True,
+        nlayers: int = 1,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+
+        self.hidden_size = config.image_emb_dim
+        self.total_num_heads = config.image_num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+
+        assert self.hidden_size % self.total_num_heads == 0
+        assert self.total_num_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = self.hidden_size // self.total_num_heads
+
+        self.total_num_kv_heads = config.image_num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.wq = ColumnParallelLinear(
+            nlayers * self.hidden_size,
+            self.total_num_heads * self.head_dim,
+            bias=use_bias,
+            quant_config=quant_config,
+        )
+        self.wk = ColumnParallelLinear(
+            nlayers * self.hidden_size,
+            self.total_num_kv_heads * self.head_dim,
+            bias=use_bias,
+            quant_config=quant_config,
+        )
+        self.wv = ColumnParallelLinear(
+            nlayers * self.hidden_size,
+            self.total_num_kv_heads * self.head_dim,
+            bias=use_bias,
+            quant_config=quant_config,
+        )
+        self.wo = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=use_bias,
+            quant_config=quant_config,
+        )
+
+        # Detect attention implementation.
+        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+        if selected_backend is None:
+            backend_by_env_var: Optional[str] = envs.APHRODITE_ATTENTION_BACKEND
+            if backend_by_env_var is not None:
+                selected_backend = backend_name_to_enum(backend_by_env_var)
+        if selected_backend is None:
+            # For Volta and Turing GPUs, use xformers instead.
+            device_available = current_platform.get_device_capability()[0] >= 8
+            if device_available:
+                from transformers.utils import is_flash_attn_2_available
+                if is_flash_attn_2_available():
+                    self._use_flash_attn = True
+                else:
+                    log_once(
+                    level="WARNING",
+                    message=
+                        "Current Molmo implementation has a bug with "
+                        "`aphrodite-flash-attn` inside vision module, so we use"
+                        " xformers backend instead. You can run `pip install "
+                        "flash-attn to use flash-attention backend."
+                    )
+                    self._use_flash_attn = False
+            else:
+                self._use_flash_attn = False
+        else:
+            if selected_backend == _Backend.FLASH_ATTN:
+                self._use_flash_attn = True
+            elif selected_backend == _Backend.XFORMERS:
+                self._use_flash_attn = False
+            else:
+                raise RuntimeError(
+                    f"Molmo does not support {selected_backend} backend now.")
+
+    def forward(self,
+                inputs_q: torch.Tensor,
+                inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        if inputs_kv is not None:
+            inputs_k = inputs_kv
+            inputs_v = inputs_kv
+        else:
+            inputs_k = inputs_q
+            inputs_v = inputs_q
+
+        xq, _ = self.wq(inputs_q)
+        xk, _ = self.wk(inputs_k)
+        xv, _ = self.wv(inputs_v)
+        q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
+        kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
+        xq = xq.view(*q_shape)
+        xk = xk.view(*kv_shape)
+        xv = xv.view(*kv_shape)
+
+        if self._use_flash_attn:
+            from flash_attn import flash_attn_func
+            output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
+        else:
+            from xformers import ops as xops
+            output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
+
+        output = rearrange(output, "b s h d -> b s (h d)").contiguous()
+        output, _ = self.wo(output)
+
+        return output
+
+
+class ResidualAttentionBlock(nn.Module):
+    """Residual attention block used in Vision Transformer."""
+
+    def __init__(
+        self,
+        config: VisionBackboneConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.attention = MultiHeadDotProductAttention(
+            config, quant_config=quant_config)
+        self.feed_forward = ViTMLP(config, quant_config)
+        self.attention_norm = nn.LayerNorm(
+            config.image_emb_dim,
+            eps=config.image_norm_eps,
+        )
+        self.ffn_norm = nn.LayerNorm(
+            config.image_emb_dim,
+            eps=config.image_norm_eps,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x + self.attention(self.attention_norm(x))
+        x = x + self.feed_forward(self.ffn_norm(x))
+        return x
+
+
+class BlockCollection(nn.Module):
+    """Collection of residual attention blocks used in Vision Transformer."""
+
+    def __init__(
+        self,
+        config: VisionBackboneConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.resblocks = nn.ModuleList([
+            ResidualAttentionBlock(config, quant_config)
+            for _ in range(config.image_num_layers)
+        ])
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        hidden_states = []
+        for r in self.resblocks:
+            x = r(x)
+            hidden_states.append(x)
+        return hidden_states
+
+
+def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
+    return token.view(1, 1, -1).expand(batch_size, -1, -1)
+
+
+class VisionTransformer(nn.Module):
+    """Vision Transformer used in Vision Backbone."""
+
+    def __init__(
+        self,
+        config: VisionBackboneConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        scale = config.image_emb_dim**-0.5
+        self.patch_num = config.image_num_patch
+        self.class_embedding = nn.Parameter(
+            torch.randn(config.image_emb_dim) * scale)
+        self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
+        self.positional_embedding = nn.Parameter(
+            torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
+        image_patch_size = config.image_patch_size
+        self.patch_embedding = nn.Linear(
+            image_patch_size * image_patch_size * 3,
+            config.image_emb_dim,
+            bias=False,
+        )
+        self.pre_ln = nn.LayerNorm(config.image_emb_dim,
+                                   eps=config.image_norm_eps)
+        self.transformer = BlockCollection(config, quant_config)
+
+    def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
+        cls_emb = self.positional_embedding[0:1]
+        pos_emb = self.positional_embedding[1:]
+
+        pos_emb = pos_emb.reshape(
+            (int(math.sqrt(pos_emb.shape[0])),
+             int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
+
+        (patch_num_0, patch_num_1) = patch_num
+
+        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
+            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
+            pos_emb = F.interpolate(
+                pos_emb,
+                size=(patch_num_0, patch_num_1),
+                mode="bicubic",
+                align_corners=False,
+                antialias=True,
+            )
+            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
+
+        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
+        x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
+                          dim=1).to(x.dtype)
+        return x
+
+    def forward(self,
+                x: torch.Tensor,
+                patch_num: int = None) -> List[torch.Tensor]:
+        """
+        : param x: (batch_size, num_patch, n_pixels)
+        """
+        if patch_num is None:
+            patch_num = self.patch_num
+        B, N, D = x.shape
+
+        x = self.patch_embedding(x)
+
+        # class embeddings and positional embeddings
+        x = torch.cat(
+            [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
+            dim=1)
+        x = self.add_pos_emb(x, patch_num)
+
+        x = self.pre_ln(x)
+
+        hidden_states = self.transformer(x)
+        return hidden_states
+
+
+class MolmoAttention(nn.Module):
+    """Molmo's LLM attention."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+
+        assert self.hidden_size % self.total_num_heads == 0
+        assert self.total_num_heads % self.tp_size == 0
+
+        self.num_heads = self.total_num_heads // self.tp_size
+        self.total_num_kv_heads = config.num_key_value_heads \
+            or self.total_num_heads
+        if self.total_num_kv_heads >= self.tp_size:
+            assert self.total_num_kv_heads % self.tp_size == 0
+        else:
+            assert self.tp_size % self.total_num_kv_heads == 0
+
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+
+        # Attention input projection. Projects x -> (q, k, v)
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.qkv_bias,
+            quant_config=quant_config,
+        )
+
+        self.tp_rank: Optional[int] = None
+        self.k_norm: Optional[nn.Module] = None
+        self.q_norm: Optional[nn.Module] = None
+        if config.attention_layer_norm:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
+                                  eps=config.layer_norm_eps)
+            self.q_norm = RMSNorm(config.hidden_size,
+                                  eps=config.layer_norm_eps)
+
+        # Rotary embeddings.
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
+
+        # Attention output projection.
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def _apply_qk_norm(self, q: torch.Tensor,
+                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.q_norm is not None and self.k_norm is not None:
+            q, k = self._apply_qk_norm(q, k)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MolmoMLP(nn.Module):
+    """Molmo's LLM mlp."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        input_dim: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size // 2
+
+        # Feed-forward input projection.
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_dim or self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        # Activation function.
+        self.act_fn = SiluAndMul()
+
+        # Feed-forward output projection.
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class MolmoDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        # Attention block.
+        self.self_attn = MolmoAttention(config, cache_config, quant_config)
+
+        # MLP block.
+        self.mlp = MolmoMLP(config, quant_config=quant_config)
+
+        # LayerNorm
+        assert config.layer_norm_type == "rms"
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.layer_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = hidden_states + residual
+        residual = hidden_states
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = hidden_states + residual
+        residual = None
+        return hidden_states, residual
+
+
+class MolmoVisionBackbone(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        vision_config: VisionBackboneConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.vit_layers = VIT_LAYERS
+        self.image_num_patch = vision_config.image_num_patch
+        self.llm_patches_per_crop = (
+            (self.image_num_patch[0] + 1) // 2,
+            (self.image_num_patch[1] + 1) // 2,
+        )
+        self.image_vit = VisionTransformer(vision_config,
+                                           quant_config=quant_config)
+        self.num_prefix_tokens = self.image_vit.num_prefix_tokens
+        assert self.num_prefix_tokens in {
+            0, 1
+        }, "Only 0 or 1 prefix tokens are supported"
+        self.image_pooling_2d = MultiHeadDotProductAttention(
+            vision_config,
+            nlayers=len(self.vit_layers),
+            quant_config=quant_config)
+        self.image_projector = MolmoMLP(
+            config,
+            input_dim=vision_config.image_emb_dim,
+            quant_config=quant_config,
+        )
+
+        image_dim = vision_config.image_emb_dim * len(self.vit_layers)
+        self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.image_vit.patch_embedding.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.image_vit.patch_embedding.weight.device
+
+    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
+        """
+        : param images: (batch_size, num_crops, num_patch, n_pixels)
+        """
+        B, T, N, D = images.shape
+
+        mask = ~torch.all(
+            images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
+
+        images = images.view(B * T, N, D)
+        image_features = self.image_vit(images)
+
+        if self.vit_layers is not None:
+            features = []
+            for layer in self.vit_layers:
+                features.append(image_features[layer])
+            image_features = torch.cat(features, dim=-1)
+        else:
+            image_features = image_features[-1]
+
+        if self.num_prefix_tokens > 0:
+            image_features = image_features[:, 1:]
+
+        image_features = image_features * mask
+        image_features = image_features.view(B, T, N, -1)
+
+        return image_features
+
+    def forward(
+        self, images: torch.Tensor, image_masks: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+
+        # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
+        batch_size, num_image = images.shape[:2]
+        images = images.to(device=self.device, dtype=self.dtype)
+        image_features = self.encode_image(images)
+
+        og_dtype = image_features.dtype
+        assert image_masks is not None
+        pad_embed = self.pad_embed[:, None, None, None, :]
+        all_pad = image_masks == 0
+        partial_pad = torch.logical_and(
+            image_masks < 1,
+            torch.logical_not(all_pad)).to(dtype=torch.float32)
+        all_pad = all_pad.to(dtype=torch.float32)
+        image_features = image_features + pad_embed[0] * torch.unsqueeze(
+            all_pad, -1)
+        image_features = image_features + pad_embed[1] * torch.unsqueeze(
+            partial_pad, -1)
+
+        image_features = image_features.to(og_dtype)
+
+        image_features = image_features.reshape(
+            (batch_size, num_image) + self.image_num_patch + (-1, ), )
+
+        if self.image_num_patch[0] % 2 == 1:
+            # Pad so we can still pool 2x2 patches
+            image_features = F.pad(
+                image_features,
+                (0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
+            )
+
+        # image pooling
+        image_features = rearrange(
+            image_features,
+            'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
+            dh=2,
+            dw=2,
+        )
+
+        query = image_features.mean(-2, keepdim=True)
+        image_features = self.image_pooling_2d(query, image_features)
+
+        h, w = self.llm_patches_per_crop
+        image_features = image_features.view(batch_size, num_image, h * w, -1)
+
+        image_features = self.image_projector(image_features)
+
+        # image_features: (batch_size, num_image, num_patch, d_model)
+        return image_features
+
+
+class MolmoModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.embedding_size = config.embedding_size or config.vocab_size
+        self.embedding_size += ADDITIONAL_VOCAB_SIZE
+        self.embed_tokens = VocabParallelEmbedding(
+            self.embedding_size,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+
+        decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \
+            else MolmoDecoderLayer
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: decoder_layer(config, cache_config, quant_config),
+            prefix=f"{prefix}.layers",
+        )
+
+        assert config.layer_norm_type == "rms"
+        self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        # Apply blocks one-by-one.
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i - self.start_layer],
+                attn_metadata,
+                residual,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        if residual is not None:
+            hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+cached_get_processor = lru_cache(get_processor)
+
+
+def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int,
+                    right_margin: int, pooling_size: int) -> int:
+    crop_window_patches = crop_patches - (left_margin + right_margin)
+    if num_tiles > 1:
+        left_crop_window_patches = (crop_window_patches + left_margin +
+                                    pooling_size -
+                                    1) // pooling_size * pooling_size
+        middle_crop_window_patches = (crop_window_patches + pooling_size -
+                                      1) // pooling_size * pooling_size
+        right_crop_window_patches = (crop_window_patches + right_margin +
+                                     pooling_size -
+                                     1) // pooling_size * pooling_size
+        return left_crop_window_patches + (
+            num_tiles -
+            2) * middle_crop_window_patches + right_crop_window_patches
+    else:
+        single_crop_window_patches = (crop_patches + pooling_size -
+                                      1) // pooling_size * pooling_size
+        return single_crop_window_patches
+
+
+def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int,
+               left_margin: int, right_margin: int, pooling_size: int) -> int:
+    h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin,
+                        pooling_size)
+    w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin,
+                        pooling_size)
+    per_row = w // pooling_size + 1
+    joint = per_row * (h // pooling_size) + 2
+    image_token_length = (crop_patches + pooling_size - 1) // pooling_size
+    resize = (image_token_length + 1) * image_token_length + 2
+    return resize + joint
+
+
+def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int,
+                   right_margin: int, pooling_size: int) -> int:
+    tilings = []
+    for i in range(1, max_crops + 1):
+        for j in range(1, max_crops + 1):
+            if i * j <= max_crops:
+                tilings.append((i, j))
+    tokens = [
+        get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin,
+                   right_margin, pooling_size) for i in range(len(tilings))
+    ]
+    return max(tokens)
+
+
+def get_max_molmo_image_tokens(ctx: InputContext) -> int:
+    processor = cached_get_processor(ctx.model_config.model,
+                                     trust_remote_code=True,
+                                     revision=ctx.model_config.code_revision)
+    image_processor = processor.image_processor
+    max_llm_image_tokens = get_max_tokens(
+        image_processor.max_crops,
+        image_processor.base_image_input_size[0] //
+        image_processor.image_patch_size,
+        image_processor.overlap_margins[0],
+        image_processor.overlap_margins[1],
+        2,
+    )
+    return max_llm_image_tokens
+
+
+# NOTE: preprocessing for the image data has been included in the
+# 'input_processor_for_molmo' function
+def image_input_mapper_for_molmo(
+    ctx: InputContext,
+    data: object,
+):
+    return MultiModalInputs(data)
+
+
+def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
+                         mm_counts: Mapping[str, int]):
+    processor = cached_get_processor(ctx.model_config.model,
+                                     trust_remote_code=True,
+                                     revision=ctx.model_config.code_revision)
+    image_processor = processor.image_processor
+
+    base_image_input_d = image_processor.image_patch_size
+    left_margin, right_margin = image_processor.overlap_margins
+    max_crops = image_processor.max_crops
+
+    # Assume: prompt_token_ids always starts with bos_token_id followed image tokens # noqa: E501
+    max_llm_image_tokens = get_max_molmo_image_tokens(ctx)
+    if seq_len - max_llm_image_tokens - 1 < 0:
+        raise RuntimeError(
+            f"Molmo cannot process {max_crops} crops in a prompt, "
+            "please increase max_model_len or reduce number of crops")
+
+    # The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501
+    tiling = (max_crops, 1)
+    total_margin_pixels = base_image_input_d * (right_margin + left_margin)
+    crop_patches = image_processor.base_image_input_size[
+        0] // base_image_input_d
+    crop_window_patches = crop_patches - (right_margin + left_margin)
+    crop_window_size = crop_window_patches * base_image_input_d
+
+    h = crop_window_size * tiling[0] + total_margin_pixels
+    w = crop_window_size * tiling[1] + total_margin_pixels
+
+    dummy_image = Image.new("RGB", (w, h), color="red")
+
+    out = processor.process("dummy prompt", dummy_image)
+
+    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
+                      out["input_ids"][:1 + max_llm_image_tokens])
+    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - max_llm_image_tokens - 1)
+    dummy_seqdata = SequenceData(token_ids)
+    dummy_imgdata = {
+        "images": out["images"],
+        "image_input_idx": out["image_input_idx"],
+    }
+    if "image_masks" in out:
+        dummy_imgdata["image_masks"] = out["image_masks"]
+    dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
+    return dummy_seqdata, {"image": dummy_imgdata}
+
+
+def pad_images(
+    max_total_crops: int,
+    images: torch.Tensor,
+    image_input_idx: torch.Tensor,
+    image_masks: Optional[torch.Tensor] = None,
+):
+    n = max_total_crops - images.shape[0]
+    images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1)
+    image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1)
+    if image_masks is not None:
+        image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1)
+    return images, image_input_idx, image_masks
+
+
+def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
+    prompt = llm_inputs["prompt"]
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    image = multi_modal_data.get("image")
+    processor = cached_get_processor(ctx.model_config.model,
+                                     trust_remote_code=True,
+                                     revision=ctx.model_config.code_revision)
+
+    # NOTE: message formatting for raw text prompt is only applied for
+    # offline inference; for online inference, the prompt is always in
+    # instruction format and tokenized.
+    if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$",
+                                       prompt):
+        out = processor.process(prompt, image, message_format="none")
+    elif prompt is not None:
+        out = processor.process(prompt, image)
+    else:
+        out = processor.process(None,
+                                image,
+                                tokens=llm_inputs["prompt_token_ids"])
+
+    image_processor = processor.image_processor
+    max_total_crops = 1 + image_processor.max_crops
+    if image is not None:
+        images, image_input_idx, image_masks = pad_images(
+            max_total_crops,
+            out["images"],
+            out["image_input_idx"],
+            out.get("image_masks"),
+        )
+    else:
+        base_image_input_size = image_processor.base_image_input_size
+        image_patch_size = image_processor.image_patch_size
+        image_num_patch = (
+            base_image_input_size[0] // image_patch_size,
+            base_image_input_size[1] // image_patch_size,
+        )
+        n_pixels = image_patch_size * image_patch_size * 3
+        n_patches = image_num_patch[0] * image_num_patch[1]
+
+        image_length_w = image_processor.image_token_length_w
+        image_length_h = image_processor.image_token_length_h
+        tokens_per_image = image_length_w * image_length_h
+        images = torch.full(
+            (max_total_crops, n_patches, n_pixels),
+            -1,
+            dtype=torch.float32,
+        )
+        image_input_idx = torch.full(
+            (max_total_crops, tokens_per_image),
+            -1,
+            dtype=torch.int32,
+        )
+        if image_processor.image_padding_mask:
+            image_masks = torch.full(
+                (max_total_crops, n_patches),
+                -1,
+                dtype=torch.float32,
+            )
+
+    image_data = dict(
+        images=images,
+        image_input_idx=image_input_idx,
+    )
+    if image_masks is not None:
+        image_data["image_masks"] = image_masks
+
+    image_data["seq_len"] = torch.tensor(len(out["input_ids"]),
+                                         dtype=torch.long)
+
+    multi_modal_data = dict(image=image_data)
+
+    return LLMInputs(
+        prompt_token_ids=out["input_ids"],
+        prompt=llm_inputs["prompt"],
+        multi_modal_data=multi_modal_data,
+    )
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
+class MolmoForCausalLM(nn.Module, SupportsMultiModal):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        multimodal_config: Optional[MultiModalConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[Mapping[str, Any]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        vision_config = VisionBackboneConfig()
+        self.vision_backbone = MolmoVisionBackbone(config, vision_config,
+                                                   quant_config)
+        self.model = MolmoModel(config, cache_config, quant_config)
+
+        if self.config.weight_tying:
+            self.lm_head = self.model.transformer.wte
+        else:
+            self.lm_head = ParallelLMHead(
+                config.embedding_size or config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+            )
+
+        self.logits_processor = LogitsProcessor(config.embedding_size
+                                                or config.vocab_size)
+        self.sampler = Sampler()
+
+    def _parse_and_validate_image_input(
+        self,
+        **kwargs: object,
+    ) -> Optional[MolmoImageInputs]:
+        images = kwargs.pop("images", None)
+        image_masks = kwargs.pop("image_masks", None)
+        if images is None:
+            return None
+
+        image_input_idx = kwargs.pop("image_input_idx", None)
+        seq_len = kwargs.pop("seq_len", None)
+        if image_input_idx is None:
+            raise ValueError("image_input_idx is required for Molmo model.")
+        if seq_len is None:
+            raise ValueError("seq_len is required for Molmo model.")
+        if not isinstance(seq_len, torch.Tensor):
+            seq_len = torch.tensor(seq_len)
+
+        return MolmoImageInputs(
+            images=images,
+            image_input_idx=image_input_idx,
+            seq_len=seq_len,
+            image_masks=image_masks,
+        )
+
+    def _process_image_input(
+        self,
+        image_input: MolmoImageInputs,
+    ) -> torch.Tensor:
+
+        image_features = self.vision_backbone(
+            images=image_input["images"],
+            image_masks=image_input["image_masks"],
+        )
+
+        return image_features
+
+    def _merge_multimodal_embeddings(
+        self,
+        inputs_embeds: torch.Tensor,
+        image_features: torch.Tensor,
+        image_input_idx: torch.Tensor,
+        seq_len: Union[torch.Tensor, List[torch.Tensor]],
+    ) -> torch.Tensor:
+        batch_size, num_image, num_patch = image_features.shape[:3]
+        assert image_input_idx.shape == (batch_size, num_image, num_patch)
+
+        image_features = image_features.to(inputs_embeds.device)
+        seq_len = seq_len.to(inputs_embeds.device)
+
+        # insert the image feature into the embedding.
+        image_features = image_features.view(batch_size, num_image * num_patch,
+                                             -1)
+        image_input_idx = image_input_idx.view(batch_size,
+                                               num_image * num_patch)
+
+        valid = image_input_idx >= 0
+        image_features = image_features * valid[:, :, None].to(
+            image_features.dtype)
+        image_features = image_features.view(
+            batch_size * num_image * num_patch, -1).contiguous()
+
+        image_input_idx = image_input_idx * valid.to(image_input_idx.dtype)
+        offset = torch.cat(
+            [seq_len.new_zeros(
+                (1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None]
+        image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
+        image_input_idx = image_input_idx.flatten()[:, None]
+        mat = image_input_idx == torch.arange(
+            seq_len.sum().item(), device=inputs_embeds.device)[None, :]
+        mat = mat.to(image_features.dtype)
+
+        inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md',
+                                                     image_features, mat)
+
+        return inputs_embeds
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.LongTensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        **kwargs: object,
+    ) -> SamplerOutput:
+
+        image_input = self._parse_and_validate_image_input(**kwargs)
+
+        if image_input is not None:
+            inputs_embeds = self.model.embed_tokens(input_ids)
+            image_features = self._process_image_input(image_input)
+
+            inputs_embeds = self._merge_multimodal_embeddings(
+                inputs_embeds,
+                image_features,
+                image_input["image_input_idx"],
+                image_input["seq_len"],
+            )
+
+            input_ids = None
+        else:
+            inputs_embeds = None
+
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            inputs_embeds=inputs_embeds,
+        )
+
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        params_mapping = [
+            ("model.transformer.ln_f.weight", "model.norm.weight"),
+            ("attn_out", "self_attn.o_proj"),
+            ("att_proj", "self_attn.qkv_proj"),
+            ("q_norm", "self_attn.q_norm"),
+            ("k_norm", "self_attn.k_norm"),
+            ("attn_norm", "input_layernorm"),
+            ("ff_norm", "post_attention_layernorm"),
+        ]
+
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        embedding_weight = dict()
+        projector_weight = dict()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+
+            if "wte.embedding" in name:
+                embedding_weight["embedding"] = loaded_weight
+                continue
+
+            if "wte.new_embedding" in name:
+                embedding_weight["new_embedding"] = loaded_weight
+                continue
+
+            if "vision_backbone" in name:
+                if name.startswith("model"):
+                    name = name[len("model."):]
+                if 'image_projector' in name:
+                    if 'w1' in name:
+                        projector_weight['gate_proj'] = loaded_weight
+                    elif 'w3' in name:
+                        projector_weight['up_proj'] = loaded_weight
+                    elif 'w2' in name:
+                        projector_weight['down_proj'] = loaded_weight
+                    else:
+                        raise ValueError(
+                            f"Unexpected projector weight: {name}")
+                    continue
+            else:
+                if "transformer.blocks" in name:
+                    name = name.replace("transformer.blocks", "layers")
+
+                if "ff_proj" in name:
+                    name = name.replace("ff_proj", "mlp.gate_up_proj")
+                    assert 'weight' in name
+                    up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
+                    loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
+
+                elif "ff_out" in name:
+                    if "layers" in name:
+                        name = name.replace("ff_out", "mlp.down_proj")
+                    else:
+                        # lm head
+                        name = name.replace("model.transformer.ff_out",
+                                            "lm_head")
+
+                else:
+                    for (param_name, weight_name) in params_mapping:
+                        if param_name in name:
+                            name = name.replace(param_name, weight_name)
+                            break
+
+            try:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+            except KeyError:
+                raise ValueError(f"Unexpected weight: {name}") from None
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        gate_up_proj_weight = torch.cat(
+            [projector_weight["gate_proj"], projector_weight["up_proj"]],
+            dim=0)
+        name = "vision_backbone.image_projector.gate_up_proj.weight"
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, gate_up_proj_weight)
+
+        down_proj_weight = projector_weight["down_proj"]
+        name = "vision_backbone.image_projector.down_proj.weight"
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, down_proj_weight)
+
+        embedding_weight = torch.cat(
+            [embedding_weight["embedding"], embedding_weight["new_embedding"]],
+            dim=0)
+        name = "model.embed_tokens.weight"
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, embedding_weight)

+ 1 - 2
aphrodite/modeling/models/qwen2_vl.py

@@ -1127,8 +1127,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
                         continue
                     param = params_dict[name]
                 except KeyError:
-                    print(params_dict.keys())
-                    raise
+                    raise ValueError(f"Unexpected weight: {name}") from None
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader
                 )

+ 16 - 0
examples/vision/vision_example.py

@@ -238,6 +238,21 @@ def run_qwen2_vl(question):
     return llm, prompt, stop_token_ids
 
 
+# Molmo
+def run_molmo(question):
+    model_name = "allenai/Molmo-7B-D-0924"
+
+    llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        dtype="bfloat16",
+    )
+
+    prompt = question
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,
@@ -251,6 +266,7 @@ model_example_map = {
     "internvl_chat": run_internvl,
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
+    "molmo": run_molmo,
 }