import logging
import math
import re
from array import array
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict,
                    Union)

import torch
from einops import rearrange
from PIL import Image
from torch import nn
from torch.nn import functional as F
from transformers import PretrainedConfig

import aphrodite.common.envs as envs
from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.attention.selector import (_Backend, backend_name_to_enum,
                                          get_global_forced_attn_backend)
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.logger import log_once
from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       IntermediateTensors, SequenceData)
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import QuickGELU, SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsMultiModal
from aphrodite.modeling.models.utils import make_layers
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.processor import get_processor

log = logging.getLogger(__name__)

# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
NUM_PREFIX_TOKENS = 1
ADDITIONAL_VOCAB_SIZE = 128


class MolmoImageInputs(TypedDict):
    images: torch.Tensor
    """Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""

    image_input_idx: torch.Tensor
    """Shape: `(batch_size, num_crops, num_patch)`"""

    seq_len: torch.Tensor
    """Shape: `(batch_size, )`"""

    image_masks: Optional[torch.Tensor]
    """Shape: `(batch_size, num_crops, num_patch)`"""


@dataclass
class VisionBackboneConfig:
    image_default_input_size: Tuple[int, int] = (336, 336)
    image_patch_size: int = 14
    image_pos_patch_size: int = 14
    image_emb_dim: int = 1024
    image_num_heads: int = 16
    image_num_key_value_heads: int = 16
    image_num_layers: int = 23
    image_mlp_dim: int = 4096
    image_mlp_activations: str = "quick_gelu"
    image_num_pos: int = 577
    image_norm_eps: float = 1e-5

    def __post_init__(self):
        self.image_default_input_size = tuple(
            self.image_default_input_size)  # type: ignore[assignment]

    @property
    def image_num_patch(self):
        h, w = self.image_default_input_size
        return h // self.image_patch_size, w // self.image_patch_size


class ViTMLP(nn.Module):
    """MLP used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.w1 = ColumnParallelLinear(
            config.image_emb_dim,
            config.image_mlp_dim,
            bias=True,
            quant_config=quant_config,
        )
        # Activation function.
        assert config.image_mlp_activations == "quick_gelu"
        self.act = QuickGELU()
        self.w2 = RowParallelLinear(
            config.image_mlp_dim,
            config.image_emb_dim,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.w1(x)
        x = self.act(x)
        x, _ = self.w2(x)
        return x


class MultiHeadDotProductAttention(nn.Module):
    """Multi-head attention used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        use_bias: bool = True,
        nlayers: int = 1,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.image_emb_dim
        self.total_num_heads = config.image_num_heads
        tp_size = get_tensor_model_parallel_world_size()

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.image_num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.wq = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wk = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wv = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )

        # Detect attention implementation.
        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
        if selected_backend is None:
            backend_by_env_var: Optional[str] = \
                envs.APHRODITE_ATTENTION_BACKEND
            if backend_by_env_var is not None:
                selected_backend = backend_name_to_enum(backend_by_env_var)
        if selected_backend is None:
            # For Volta and Turing GPUs, use xformers instead.
            device_available = current_platform.get_device_capability()[0] >= 8
            if device_available:
                from transformers.utils import is_flash_attn_2_available

                if is_flash_attn_2_available():
                    self._use_flash_attn = True
                else:
                    log_once(
                        level="WARNING",
                        message="Current Molmo implementation has a bug with "
                        "`aphrodite-flash-attn` inside the vision module, so "
                        "the xformers backend is used instead. You can run "
                        "`pip install flash-attn` to use the flash-attention "
                        "backend.")
                    self._use_flash_attn = False
            else:
                self._use_flash_attn = False
        else:
            if selected_backend == _Backend.FLASH_ATTN:
                self._use_flash_attn = True
            elif selected_backend == _Backend.XFORMERS:
                self._use_flash_attn = False
            else:
                raise RuntimeError(
                    f"Molmo does not support {selected_backend} backend now.")

    def forward(self,
                inputs_q: torch.Tensor,
                inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
        if inputs_kv is not None:
            inputs_k = inputs_kv
            inputs_v = inputs_kv
        else:
            inputs_k = inputs_q
            inputs_v = inputs_q

        xq, _ = self.wq(inputs_q)
        xk, _ = self.wk(inputs_k)
        xv, _ = self.wv(inputs_v)

        q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
        kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
        xq = xq.view(*q_shape)
        xk = xk.view(*kv_shape)
        xv = xv.view(*kv_shape)

        if self._use_flash_attn:
            from flash_attn import flash_attn_func
            output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
        else:
            from xformers import ops as xops
            output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)

        output = rearrange(output, "b s h d -> b s (h d)").contiguous()
        output, _ = self.wo(output)

        return output


class ResidualAttentionBlock(nn.Module):
    """Residual attention block used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.attention = MultiHeadDotProductAttention(
            config, quant_config=quant_config)
        self.feed_forward = ViTMLP(config, quant_config)
        self.attention_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )
        self.ffn_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x))
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


class BlockCollection(nn.Module):
    """Collection of residual attention blocks used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(config, quant_config)
            for _ in range(config.image_num_layers)
        ])

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        hidden_states = []
        for r in self.resblocks:
            x = r(x)
            hidden_states.append(x)
        return hidden_states


def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
    return token.view(1, 1, -1).expand(batch_size, -1, -1)


class VisionTransformer(nn.Module):
    """Vision Transformer used in Vision Backbone."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        scale = config.image_emb_dim**-0.5
        self.patch_num = config.image_num_patch
        self.class_embedding = nn.Parameter(
            torch.randn(config.image_emb_dim) * scale)
        self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
        self.positional_embedding = nn.Parameter(
            torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
        image_patch_size = config.image_patch_size
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.image_emb_dim,
            bias=False,
        )
        self.pre_ln = nn.LayerNorm(config.image_emb_dim,
                                   eps=config.image_norm_eps)
        self.transformer = BlockCollection(config, quant_config)

    def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
        cls_emb = self.positional_embedding[0:1]
        pos_emb = self.positional_embedding[1:]

        pos_emb = pos_emb.reshape(
            (int(math.sqrt(pos_emb.shape[0])),
             int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))

        (patch_num_0, patch_num_1) = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
                          dim=1).to(x.dtype)
        return x

    def forward(self,
                x: torch.Tensor,
                patch_num: int = None) -> List[torch.Tensor]:
        """
        : param x: (batch_size, num_patch, n_pixels)
        """
        if patch_num is None:
            patch_num = self.patch_num
        B, N, D = x.shape

        x = self.patch_embedding(x)

        # class embeddings and positional embeddings
        x = torch.cat(
            [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
            dim=1)
        x = self.add_pos_emb(x, patch_num)

        x = self.pre_ln(x)

        hidden_states = self.transformer(x)
        return hidden_states


class MolmoAttention(nn.Module):
    """Molmo's LLM attention."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = config.num_key_value_heads \
            or self.total_num_heads
        if self.total_num_kv_heads >= self.tp_size:
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            assert self.tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )

        self.tp_rank: Optional[int] = None
        self.k_norm: Optional[nn.Module] = None
        self.q_norm: Optional[nn.Module] = None
        if config.attention_layer_norm:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
                                  eps=config.layer_norm_eps)
            self.q_norm = RMSNorm(config.hidden_size,
                                  eps=config.layer_norm_eps)

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm.forward_native(q)
        k = self.k_norm.forward_native(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.q_norm is not None and self.k_norm is not None:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MolmoMLP(nn.Module):
    """Molmo's LLM mlp."""

    def __init__(
        self,
        config: PretrainedConfig,
        input_dim: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size // 2

        # Feed-forward input projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_dim or self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )

        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MolmoDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Attention block.
        self.self_attn = MolmoAttention(config, cache_config, quant_config)

        # MLP block.
        self.mlp = MolmoMLP(config, quant_config=quant_config)

        # LayerNorm
        assert config.layer_norm_type == "rms"
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = None
        return hidden_states, residual


class MolmoVisionBackbone(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        vision_config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.vit_layers = VIT_LAYERS
        self.image_num_patch = vision_config.image_num_patch
        self.llm_patches_per_crop = (
            (self.image_num_patch[0] + 1) // 2,
            (self.image_num_patch[1] + 1) // 2,
        )
        self.image_vit = VisionTransformer(vision_config,
                                           quant_config=quant_config)
        self.num_prefix_tokens = self.image_vit.num_prefix_tokens
        assert self.num_prefix_tokens in {
            0, 1
        }, "Only 0 or 1 prefix tokens are supported"
        self.image_pooling_2d = MultiHeadDotProductAttention(
            vision_config,
            nlayers=len(self.vit_layers),
            quant_config=quant_config)
        self.image_projector = MolmoMLP(
            config,
            input_dim=vision_config.image_emb_dim,
            quant_config=quant_config,
        )

        image_dim = vision_config.image_emb_dim * len(self.vit_layers)
        self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))

    @property
    def dtype(self) -> torch.dtype:
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.image_vit.patch_embedding.weight.device

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        : param images: (batch_size, num_crops, num_patch, n_pixels)
        """
        B, T, N, D = images.shape

        mask = ~torch.all(
            images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)

        images = images.view(B * T, N, D)
        image_features = self.image_vit(images)

        if self.vit_layers is not None:
            features = []
            for layer in self.vit_layers:
                features.append(image_features[layer])
            image_features = torch.cat(features, dim=-1)
        else:
            image_features = image_features[-1]

        if self.num_prefix_tokens > 0:
            image_features = image_features[:, 1:]

        image_features = image_features * mask
        image_features = image_features.view(B, T, N, -1)

        return image_features

    def forward(
        self, images: torch.Tensor, image_masks: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

        # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)  # noqa: E501
        batch_size, num_image = images.shape[:2]

        images = images.to(device=self.device, dtype=self.dtype)

        image_features = self.encode_image(images)

        og_dtype = image_features.dtype
        assert image_masks is not None
        pad_embed = self.pad_embed[:, None, None, None, :]
        all_pad = image_masks == 0
        partial_pad = torch.logical_and(
            image_masks < 1,
            torch.logical_not(all_pad)).to(dtype=torch.float32)
        all_pad = all_pad.to(dtype=torch.float32)
        image_features = image_features + pad_embed[0] * torch.unsqueeze(
            all_pad, -1)
        image_features = image_features + pad_embed[1] * torch.unsqueeze(
            partial_pad, -1)

        image_features = image_features.to(og_dtype)

        image_features = image_features.reshape(
            (batch_size, num_image) + self.image_num_patch + (-1, ), )

        if self.image_num_patch[0] % 2 == 1:
            # Pad so we can still pool 2x2 patches
            image_features = F.pad(
                image_features,
                (0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
            )

        # image pooling
        image_features = rearrange(
            image_features,
            'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
            dh=2,
            dw=2,
        )

        query = image_features.mean(-2, keepdim=True)
        image_features = self.image_pooling_2d(query, image_features)

        h, w = self.llm_patches_per_crop
        image_features = image_features.view(batch_size, num_image, h * w, -1)

        image_features = self.image_projector(image_features)

        # image_features: (batch_size, num_image, num_patch, d_model)
        return image_features


class MolmoModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        self.embedding_size = config.embedding_size or config.vocab_size
        self.embedding_size += ADDITIONAL_VOCAB_SIZE
        self.embed_tokens = VocabParallelEmbedding(
            self.embedding_size,
            config.hidden_size,
            quant_config=quant_config,
        )

        decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \
            else MolmoDecoderLayer
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(config, cache_config, quant_config),
            prefix=f"{prefix}.layers",
        )

        assert config.layer_norm_type == "rms"
        self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Apply blocks one-by-one.
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states


cached_get_processor = lru_cache(get_processor)


def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int,
                    right_margin: int, pooling_size: int) -> int:
    crop_window_patches = crop_patches - (left_margin + right_margin)
    if num_tiles > 1:
        left_crop_window_patches = (crop_window_patches + left_margin +
                                    pooling_size -
                                    1) // pooling_size * pooling_size
        middle_crop_window_patches = (crop_window_patches + pooling_size -
                                      1) // pooling_size * pooling_size
        right_crop_window_patches = (crop_window_patches + right_margin +
                                     pooling_size -
                                     1) // pooling_size * pooling_size
        return left_crop_window_patches + (
            num_tiles -
            2) * middle_crop_window_patches + right_crop_window_patches
    else:
        single_crop_window_patches = (crop_patches + pooling_size -
                                      1) // pooling_size * pooling_size
        return single_crop_window_patches


def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int,
               left_margin: int, right_margin: int, pooling_size: int) -> int:
    h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin,
                        pooling_size)
    w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin,
                        pooling_size)
    per_row = w // pooling_size + 1
    joint = per_row * (h // pooling_size) + 2
    image_token_length = (crop_patches + pooling_size - 1) // pooling_size
    resize = (image_token_length + 1) * image_token_length + 2
    return resize + joint


def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int,
                   right_margin: int, pooling_size: int) -> int:
    tilings = []
    for i in range(1, max_crops + 1):
        for j in range(1, max_crops + 1):
            if i * j <= max_crops:
                tilings.append((i, j))
    tokens = [
        get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin,
                   right_margin, pooling_size) for i in range(len(tilings))
    ]
    return max(tokens)


def get_max_molmo_image_tokens(ctx: InputContext) -> int:
    processor = cached_get_processor(ctx.model_config.model,
                                     trust_remote_code=True,
                                     revision=ctx.model_config.code_revision)
    image_processor = processor.image_processor
    max_llm_image_tokens = get_max_tokens(
        image_processor.max_crops,
        image_processor.base_image_input_size[0] //
        image_processor.image_patch_size,
        image_processor.overlap_margins[0],
        image_processor.overlap_margins[1],
        2,
    )
    return max_llm_image_tokens


# NOTE: preprocessing for the image data has been included in the
# 'input_processor_for_molmo' function
def image_input_mapper_for_molmo(
    ctx: InputContext,
    data: object,
):
    return MultiModalInputs(data)


def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    processor = cached_get_processor(ctx.model_config.model,
                                     trust_remote_code=True,
                                     revision=ctx.model_config.code_revision)
    image_processor = processor.image_processor

    base_image_input_d = image_processor.image_patch_size
    left_margin, right_margin = image_processor.overlap_margins
    max_crops = image_processor.max_crops

    # Assume: prompt_token_ids always starts with bos_token_id followed by image tokens  # noqa: E501
    max_llm_image_tokens = get_max_molmo_image_tokens(ctx)
    if seq_len - max_llm_image_tokens - 1 < 0:
        raise RuntimeError(
            f"Molmo cannot process {max_crops} crops in a prompt, "
            "please increase max_model_len or reduce number of crops")

    # The vertical image has the maximum number of image tokens due to column tokens.  # noqa: E501
    tiling = (max_crops, 1)
    total_margin_pixels = base_image_input_d * (right_margin + left_margin)
    crop_patches = image_processor.base_image_input_size[
        0] // base_image_input_d
    crop_window_patches = crop_patches - (right_margin + left_margin)
    crop_window_size = crop_window_patches * base_image_input_d

    h = crop_window_size * tiling[0] + total_margin_pixels
    w = crop_window_size * tiling[1] + total_margin_pixels

    dummy_image = Image.new("RGB", (w, h), color="red")

    out = processor.process("dummy prompt", dummy_image)

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      out["input_ids"][:1 + max_llm_image_tokens])
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - max_llm_image_tokens - 1)
    dummy_seqdata = SequenceData(token_ids)
    dummy_imgdata = {
        "images": out["images"],
        "image_input_idx": out["image_input_idx"],
    }
    if "image_masks" in out:
        dummy_imgdata["image_masks"] = out["image_masks"]
    dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
    return dummy_seqdata, {"image": dummy_imgdata}


def pad_images(
    max_total_crops: int,
    images: torch.Tensor,
    image_input_idx: torch.Tensor,
    image_masks: Optional[torch.Tensor] = None,
):
    n = max_total_crops - images.shape[0]
    images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1)
    image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1)
    if image_masks is not None:
        image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1)
    return images, image_input_idx, image_masks


def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
    prompt = llm_inputs["prompt"]
    multi_modal_data = llm_inputs.get("multi_modal_data")
    image = multi_modal_data.get("image")
    processor = cached_get_processor(ctx.model_config.model,
                                     trust_remote_code=True,
                                     revision=ctx.model_config.code_revision)

    # NOTE: message formatting for raw text prompt is only applied for
    # offline inference; for online inference, the prompt is always in
    # instruction format and tokenized.
    if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$",
                                       prompt):
        out = processor.process(prompt, image, message_format="none")
    elif prompt is not None:
        out = processor.process(prompt, image)
    else:
        out = processor.process(None,
                                image,
                                tokens=llm_inputs["prompt_token_ids"])

    image_processor = processor.image_processor
    max_total_crops = 1 + image_processor.max_crops
    if image is not None:
        images, image_input_idx, image_masks = pad_images(
            max_total_crops,
            out["images"],
            out["image_input_idx"],
            out.get("image_masks"),
        )
    else:
        base_image_input_size = image_processor.base_image_input_size
        image_patch_size = image_processor.image_patch_size
        image_num_patch = (
            base_image_input_size[0] // image_patch_size,
            base_image_input_size[1] // image_patch_size,
        )
        n_pixels = image_patch_size * image_patch_size * 3
        n_patches = image_num_patch[0] * image_num_patch[1]

        image_length_w = image_processor.image_token_length_w
        image_length_h = image_processor.image_token_length_h
        tokens_per_image = image_length_w * image_length_h
        images = torch.full(
            (max_total_crops, n_patches, n_pixels),
            -1,
            dtype=torch.float32,
        )
        image_input_idx = torch.full(
            (max_total_crops, tokens_per_image),
            -1,
            dtype=torch.int32,
        )
        # Default to no mask so the check below never hits an unbound local
        # when padding masks are disabled.
        image_masks = None
        if image_processor.image_padding_mask:
            image_masks = torch.full(
                (max_total_crops, n_patches),
                -1,
                dtype=torch.float32,
            )

    image_data = dict(
        images=images,
        image_input_idx=image_input_idx,
    )
    if image_masks is not None:
        image_data["image_masks"] = image_masks

    image_data["seq_len"] = torch.tensor(len(out["input_ids"]),
                                         dtype=torch.long)

    multi_modal_data = dict(image=image_data)

    return LLMInputs(
        prompt_token_ids=out["input_ids"],
        prompt=llm_inputs["prompt"],
        multi_modal_data=multi_modal_data,
    )


@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: Optional[MultiModalConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[Mapping[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        vision_config = VisionBackboneConfig()
        self.vision_backbone = MolmoVisionBackbone(config, vision_config,
                                                   quant_config)
        self.model = MolmoModel(config, cache_config, quant_config)

        if self.config.weight_tying:
            self.lm_head = self.model.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                config.embedding_size or config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )

        self.logits_processor = LogitsProcessor(config.embedding_size
                                                or config.vocab_size)
        self.sampler = Sampler()

    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> Optional[MolmoImageInputs]:
        images = kwargs.pop("images", None)
        image_masks = kwargs.pop("image_masks", None)
        if images is None:
            return None

        image_input_idx = kwargs.pop("image_input_idx", None)
        seq_len = kwargs.pop("seq_len", None)
        if image_input_idx is None:
            raise ValueError("image_input_idx is required for Molmo model.")
        if seq_len is None:
            raise ValueError("seq_len is required for Molmo model.")
        if not isinstance(seq_len, torch.Tensor):
            seq_len = torch.tensor(seq_len)

        return MolmoImageInputs(
            images=images,
            image_input_idx=image_input_idx,
            seq_len=seq_len,
            image_masks=image_masks,
        )

    def _process_image_input(
        self,
        image_input: MolmoImageInputs,
    ) -> torch.Tensor:
        image_features = self.vision_backbone(
            images=image_input["images"],
            image_masks=image_input["image_masks"],
        )

        return image_features

    def _merge_multimodal_embeddings(
        self,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
        image_input_idx: torch.Tensor,
        seq_len: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        batch_size, num_image, num_patch = image_features.shape[:3]
        assert image_input_idx.shape == (batch_size, num_image, num_patch)

        image_features = image_features.to(inputs_embeds.device)
        seq_len = seq_len.to(inputs_embeds.device)

        # insert the image feature into the embedding.
        image_features = image_features.view(batch_size,
                                             num_image * num_patch, -1)
        image_input_idx = image_input_idx.view(batch_size,
                                               num_image * num_patch)

        valid = image_input_idx >= 0
        image_features = image_features * valid[:, :, None].to(
            image_features.dtype)
        image_features = image_features.view(
            batch_size * num_image * num_patch, -1).contiguous()

        image_input_idx = image_input_idx * valid.to(image_input_idx.dtype)
        offset = torch.cat(
            [seq_len.new_zeros((1)), seq_len.cumsum(dim=0)[:-1]],
            dim=0)[:, None]
        image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
        image_input_idx = image_input_idx.flatten()[:, None]
        mat = image_input_idx == torch.arange(
            seq_len.sum().item(), device=inputs_embeds.device)[None, :]
        mat = mat.to(image_features.dtype)

        inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md',
                                                     image_features, mat)

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> SamplerOutput:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            image_features = self._process_image_input(image_input)

            inputs_embeds = self._merge_multimodal_embeddings(
                inputs_embeds,
                image_features,
                image_input["image_input_idx"],
                image_input["seq_len"],
            )

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_mapping = [
            ("model.transformer.ln_f.weight", "model.norm.weight"),
            ("attn_out", "self_attn.o_proj"),
            ("att_proj", "self_attn.qkv_proj"),
            ("q_norm", "self_attn.q_norm"),
            ("k_norm", "self_attn.k_norm"),
            ("attn_norm", "input_layernorm"),
            ("ff_norm", "post_attention_layernorm"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))

        embedding_weight = dict()
        projector_weight = dict()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            if "wte.embedding" in name:
                embedding_weight["embedding"] = loaded_weight
                continue

            if "wte.new_embedding" in name:
                embedding_weight["new_embedding"] = loaded_weight
                continue

            if "vision_backbone" in name:
                if name.startswith("model"):
                    name = name[len("model."):]
                if 'image_projector' in name:
                    if 'w1' in name:
                        projector_weight['gate_proj'] = loaded_weight
                    elif 'w3' in name:
                        projector_weight['up_proj'] = loaded_weight
                    elif 'w2' in name:
                        projector_weight['down_proj'] = loaded_weight
                    else:
                        raise ValueError(
                            f"Unexpected projector weight: {name}")
                    continue
            else:
                if "transformer.blocks" in name:
                    name = name.replace("transformer.blocks", "layers")

                if "ff_proj" in name:
                    name = name.replace("ff_proj", "mlp.gate_up_proj")
                    assert 'weight' in name
                    up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
                    loaded_weight = torch.cat([gate_weight, up_weight], dim=0)

                elif "ff_out" in name:
                    if "layers" in name:
                        name = name.replace("ff_out", "mlp.down_proj")
                    else:
                        # lm head
                        name = name.replace("model.transformer.ff_out",
                                            "lm_head")

                else:
                    for (param_name, weight_name) in params_mapping:
                        if param_name in name:
                            name = name.replace(param_name, weight_name)
                            break

            try:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
            except KeyError:
                raise ValueError(f"Unexpected weight: {name}") from None

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        gate_up_proj_weight = torch.cat(
            [projector_weight["gate_proj"], projector_weight["up_proj"]],
            dim=0)
        name = "vision_backbone.image_projector.gate_up_proj.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param, gate_up_proj_weight)

        down_proj_weight = projector_weight["down_proj"]
        name = "vision_backbone.image_projector.down_proj.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param, down_proj_weight)

        embedding_weight = torch.cat(
            [embedding_weight["embedding"], embedding_weight["new_embedding"]],
            dim=0)
        name = "model.embed_tokens.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param, embedding_weight)
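

# Minimal sanity-check sketch (illustrative only, not used by the model
# above): it demonstrates the one-hot einsum scatter that
# `_merge_multimodal_embeddings` relies on to place image features at their
# flattened token positions. The toy sizes and target indices below are
# arbitrary assumptions chosen purely for the example.
if __name__ == "__main__":
    num_features, hidden_dim, total_tokens = 4, 8, 6
    toy_features = torch.randn(num_features, hidden_dim)
    # Each image-feature row targets one flattened token position.
    target_idx = torch.tensor([1, 2, 4, 5])[:, None]
    # Build the (num_features, total_tokens) one-hot scatter matrix.
    scatter = (target_idx == torch.arange(total_tokens)[None, :]).to(
        toy_features.dtype)
    toy_embeds = torch.zeros(total_tokens, hidden_dim)
    merged = toy_embeds + torch.einsum('nd,nm->md', toy_features, scatter)
    # With zero base embeddings, position 1 now holds the first image feature.
    assert torch.allclose(merged[1], toy_features[0])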