- import logging
- import math
- import re
- from array import array
- from dataclasses import dataclass
- from functools import lru_cache, partial
- from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict,
- Union)
- import torch
- from einops import rearrange
- from PIL import Image
- from torch import nn
- from torch.nn import functional as F
- from transformers import PretrainedConfig
- import aphrodite.common.envs as envs
- from aphrodite.attention import Attention, AttentionMetadata
- from aphrodite.attention.selector import (_Backend, backend_name_to_enum,
- get_global_forced_attn_backend)
- from aphrodite.common.config import CacheConfig, MultiModalConfig
- from aphrodite.common.logger import log_once
- from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
- IntermediateTensors, SequenceData)
- from aphrodite.distributed import (get_pp_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- split_tensor_along_last_dim,
- tensor_model_parallel_all_gather)
- from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
- from aphrodite.modeling.layers.activation import QuickGELU, SiluAndMul
- from aphrodite.modeling.layers.layernorm import RMSNorm
- from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
- MergedColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear)
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.rotary_embedding import get_rope
- from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
- from aphrodite.modeling.layers.vocab_parallel_embedding import (
- ParallelLMHead, VocabParallelEmbedding)
- from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
- from aphrodite.modeling.models.interfaces import SupportsMultiModal
- from aphrodite.modeling.models.utils import make_layers
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
- from aphrodite.platforms import current_platform
- from aphrodite.quantization.base_config import QuantizationConfig
- from aphrodite.transformers_utils.processor import get_processor
- log = logging.getLogger(__name__)
- # TODO: hard-coded for now. Consider making it configurable.
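- # VIT_LAYERS indexes into the list of per-block hidden states returned by the
- # vision transformer; the selected layers (2nd- and 9th-from-last) are
- # concatenated in MolmoVisionBackbone.encode_image.
- # NUM_PREFIX_TOKENS is the single CLS token prepended by the ViT, and
- # ADDITIONAL_VOCAB_SIZE reserves extra embedding rows beyond the base
- # vocabulary for Molmo's added special tokens.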
- VIT_LAYERS = [-2, -9]
- NUM_PREFIX_TOKENS = 1
- ADDITIONAL_VOCAB_SIZE = 128
- class MolmoImageInputs(TypedDict):
- images: torch.Tensor
- """Shape:
- `(batch_size, num_crops, num_patch, patch_dim)`
- """
- image_input_idx: torch.Tensor
- """Shape:
- `(batch_size, num_crops, num_patch)`
- """
- seq_len: torch.Tensor
- """Shape:
- `(batch_size, )`
- """
- image_masks: Optional[torch.Tensor]
- """Shape:
- `(batch_size, num_crops, num_patch)`
- """
- @dataclass
- class VisionBackboneConfig:
- image_default_input_size: Tuple[int, int] = (336, 336)
- image_patch_size: int = 14
- image_pos_patch_size: int = 14
- image_emb_dim: int = 1024
- image_num_heads: int = 16
- image_num_key_value_heads: int = 16
- image_num_layers: int = 23
- image_mlp_dim: int = 4096
- image_mlp_activations: str = "quick_gelu"
- image_num_pos: int = 577
- image_norm_eps: float = 1e-5
- def __post_init__(self):
- self.image_default_input_size = tuple(
- self.image_default_input_size) # type: ignore[assignment]
- @property
- def image_num_patch(self):
- h, w = self.image_default_input_size
- return h // self.image_patch_size, w // self.image_patch_size
- class ViTMLP(nn.Module):
- """MLP used in Vision Transformer."""
- def __init__(
- self,
- config: VisionBackboneConfig,
- quant_config: Optional[QuantizationConfig] = None,
- ):
- super().__init__()
- self.w1 = ColumnParallelLinear(
- config.image_emb_dim,
- config.image_mlp_dim,
- bias=True,
- quant_config=quant_config,
- )
- # Activation function.
- assert config.image_mlp_activations == "quick_gelu"
- self.act = QuickGELU()
- self.w2 = RowParallelLinear(
- config.image_mlp_dim,
- config.image_emb_dim,
- bias=True,
- quant_config=quant_config,
- )
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x, _ = self.w1(x)
- x = self.act(x)
- x, _ = self.w2(x)
- return x
- class MultiHeadDotProductAttention(nn.Module):
- """Multi-head attention used in Vision Transformer."""
- def __init__(
- self,
- config: VisionBackboneConfig,
- use_bias: bool = True,
- nlayers: int = 1,
- quant_config: Optional[QuantizationConfig] = None,
- ):
- super().__init__()
- self.hidden_size = config.image_emb_dim
- self.total_num_heads = config.image_num_heads
- tp_size = get_tensor_model_parallel_world_size()
- assert self.hidden_size % self.total_num_heads == 0
- assert self.total_num_heads % tp_size == 0
- self.num_heads = self.total_num_heads // tp_size
- self.head_dim = self.hidden_size // self.total_num_heads
- self.total_num_kv_heads = config.image_num_key_value_heads
- if self.total_num_kv_heads >= tp_size:
- assert self.total_num_kv_heads % tp_size == 0
- else:
- assert tp_size % self.total_num_kv_heads == 0
- self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
- self.wq = ColumnParallelLinear(
- nlayers * self.hidden_size,
- self.total_num_heads * self.head_dim,
- bias=use_bias,
- quant_config=quant_config,
- )
- self.wk = ColumnParallelLinear(
- nlayers * self.hidden_size,
- self.total_num_kv_heads * self.head_dim,
- bias=use_bias,
- quant_config=quant_config,
- )
- self.wv = ColumnParallelLinear(
- nlayers * self.hidden_size,
- self.total_num_kv_heads * self.head_dim,
- bias=use_bias,
- quant_config=quant_config,
- )
- self.wo = RowParallelLinear(
- self.total_num_heads * self.head_dim,
- self.hidden_size,
- bias=use_bias,
- quant_config=quant_config,
- )
- # Detect attention implementation.
- selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
- if selected_backend is None:
- backend_by_env_var: Optional[str] = envs.APHRODITE_ATTENTION_BACKEND
- if backend_by_env_var is not None:
- selected_backend = backend_name_to_enum(backend_by_env_var)
- if selected_backend is None:
- # Flash-attention requires Ampere (compute capability >= 8.0) or newer;
- # for Volta and Turing GPUs, use xformers instead.
- device_available = current_platform.get_device_capability()[0] >= 8
- if device_available:
- from transformers.utils import is_flash_attn_2_available
- if is_flash_attn_2_available():
- self._use_flash_attn = True
- else:
- log_once(
- level="WARNING",
- message=
- "Current Molmo implementation has a bug with "
- "`aphrodite-flash-attn` inside vision module, so we use"
- " xformers backend instead. You can run `pip install "
- "flash-attn to use flash-attention backend."
- )
- self._use_flash_attn = False
- else:
- self._use_flash_attn = False
- else:
- if selected_backend == _Backend.FLASH_ATTN:
- self._use_flash_attn = True
- elif selected_backend == _Backend.XFORMERS:
- self._use_flash_attn = False
- else:
- raise RuntimeError(
- f"Molmo does not support {selected_backend} backend now.")
- def forward(self,
- inputs_q: torch.Tensor,
- inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
- if inputs_kv is not None:
- inputs_k = inputs_kv
- inputs_v = inputs_kv
- else:
- inputs_k = inputs_q
- inputs_v = inputs_q
- xq, _ = self.wq(inputs_q)
- xk, _ = self.wk(inputs_k)
- xv, _ = self.wv(inputs_v)
- q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
- kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
- xq = xq.view(*q_shape)
- xk = xk.view(*kv_shape)
- xv = xv.view(*kv_shape)
- if self._use_flash_attn:
- from flash_attn import flash_attn_func
- output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
- else:
- from xformers import ops as xops
- output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
- output = rearrange(output, "b s h d -> b s (h d)").contiguous()
- output, _ = self.wo(output)
- return output
- class ResidualAttentionBlock(nn.Module):
- """Residual attention block used in Vision Transformer."""
- def __init__(
- self,
- config: VisionBackboneConfig,
- quant_config: Optional[QuantizationConfig] = None,
- ):
- super().__init__()
- self.attention = MultiHeadDotProductAttention(
- config, quant_config=quant_config)
- self.feed_forward = ViTMLP(config, quant_config)
- self.attention_norm = nn.LayerNorm(
- config.image_emb_dim,
- eps=config.image_norm_eps,
- )
- self.ffn_norm = nn.LayerNorm(
- config.image_emb_dim,
- eps=config.image_norm_eps,
- )
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = x + self.attention(self.attention_norm(x))
- x = x + self.feed_forward(self.ffn_norm(x))
- return x
- class BlockCollection(nn.Module):
- """Collection of residual attention blocks used in Vision Transformer."""
- def __init__(
- self,
- config: VisionBackboneConfig,
- quant_config: Optional[QuantizationConfig] = None,
- ):
- super().__init__()
- self.resblocks = nn.ModuleList([
- ResidualAttentionBlock(config, quant_config)
- for _ in range(config.image_num_layers)
- ])
- def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
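- # Return the output of every residual block so callers can pick intermediate
- # feature maps (e.g. the layers listed in VIT_LAYERS).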
- hidden_states = []
- for r in self.resblocks:
- x = r(x)
- hidden_states.append(x)
- return hidden_states
- def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
- return token.view(1, 1, -1).expand(batch_size, -1, -1)
- class VisionTransformer(nn.Module):
- """Vision Transformer used in Vision Backbone."""
- def __init__(
- self,
- config: VisionBackboneConfig,
- quant_config: Optional[QuantizationConfig] = None,
- ):
- super().__init__()
- scale = config.image_emb_dim**-0.5
- self.patch_num = config.image_num_patch
- self.class_embedding = nn.Parameter(
- torch.randn(config.image_emb_dim) * scale)
- self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
- self.positional_embedding = nn.Parameter(
- torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
- image_patch_size = config.image_patch_size
- self.patch_embedding = nn.Linear(
- image_patch_size * image_patch_size * 3,
- config.image_emb_dim,
- bias=False,
- )
- self.pre_ln = nn.LayerNorm(config.image_emb_dim,
- eps=config.image_norm_eps)
- self.transformer = BlockCollection(config, quant_config)
- def add_pos_emb(self, x: torch.Tensor,
- patch_num: Tuple[int, int]) -> torch.Tensor:
- cls_emb = self.positional_embedding[0:1]
- pos_emb = self.positional_embedding[1:]
- pos_emb = pos_emb.reshape(
- (int(math.sqrt(pos_emb.shape[0])),
- int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
- (patch_num_0, patch_num_1) = patch_num
- if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
- # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
- pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
- pos_emb = F.interpolate(
- pos_emb,
- size=(patch_num_0, patch_num_1),
- mode="bicubic",
- align_corners=False,
- antialias=True,
- )
- pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
- pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
- x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
- dim=1).to(x.dtype)
- return x
- def forward(self,
- x: torch.Tensor,
- patch_num: Optional[Tuple[int, int]] = None) -> List[torch.Tensor]:
- """
- :param x: (batch_size, num_patch, n_pixels)
- """
- if patch_num is None:
- patch_num = self.patch_num
- B, N, D = x.shape
- x = self.patch_embedding(x)
- # class embeddings and positional embeddings
- x = torch.cat(
- [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
- dim=1)
- x = self.add_pos_emb(x, patch_num)
- x = self.pre_ln(x)
- hidden_states = self.transformer(x)
- return hidden_states
- class MolmoAttention(nn.Module):
- """Molmo's LLM attention."""
- def __init__(
- self,
- config: PretrainedConfig,
- cache_config: Optional[CacheConfig] = None,
- quant_config: Optional[QuantizationConfig] = None,
- ) -> None:
- super().__init__()
- self.hidden_size = config.hidden_size
- self.tp_size = get_tensor_model_parallel_world_size()
- self.total_num_heads = config.num_attention_heads
- assert self.hidden_size % self.total_num_heads == 0
- assert self.total_num_heads % self.tp_size == 0
- self.num_heads = self.total_num_heads // self.tp_size
- self.total_num_kv_heads = config.num_key_value_heads \
- or self.total_num_heads
- if self.total_num_kv_heads >= self.tp_size:
- assert self.total_num_kv_heads % self.tp_size == 0
- else:
- assert self.tp_size % self.total_num_kv_heads == 0
- self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
- self.head_dim = self.hidden_size // self.total_num_heads
- self.q_size = self.num_heads * self.head_dim
- self.kv_size = self.num_kv_heads * self.head_dim
- self.max_position_embeddings = config.max_position_embeddings
- self.rope_theta = config.rope_theta
- # Attention input projection. Projects x -> (q, k, v)
- self.qkv_proj = QKVParallelLinear(
- self.hidden_size,
- self.head_dim,
- self.total_num_heads,
- self.total_num_kv_heads,
- bias=config.qkv_bias,
- quant_config=quant_config,
- )
- self.tp_rank: Optional[int] = None
- self.k_norm: Optional[nn.Module] = None
- self.q_norm: Optional[nn.Module] = None
- if config.attention_layer_norm:
- self.tp_rank = get_tensor_model_parallel_rank()
- self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
- eps=config.layer_norm_eps)
- self.q_norm = RMSNorm(config.hidden_size,
- eps=config.layer_norm_eps)
- # Rotary embeddings.
- self.rotary_emb = get_rope(
- self.head_dim,
- rotary_dim=self.head_dim,
- max_position=self.max_position_embeddings,
- base=self.rope_theta,
- )
- self.scaling = self.head_dim**-0.5
- self.attn = Attention(self.num_heads,
- self.head_dim,
- self.scaling,
- num_kv_heads=self.num_kv_heads,
- cache_config=cache_config,
- quant_config=quant_config)
- # Attention output projection.
- self.o_proj = RowParallelLinear(
- self.total_num_heads * self.head_dim,
- self.hidden_size,
- bias=False,
- quant_config=quant_config,
- )
- def _apply_qk_norm(self, q: torch.Tensor,
- k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
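- # QK norm is defined over the full (unsharded) projection width, so under
- # tensor parallelism q/k are all-gathered, normalized, and split back to
- # this rank's partition.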
- if self.tp_size > 1:
- q = tensor_model_parallel_all_gather(q.contiguous())
- k = tensor_model_parallel_all_gather(k.contiguous())
- q = self.q_norm.forward_native(q)
- k = self.k_norm.forward_native(k)
- if self.tp_size > 1:
- splitter = partial(split_tensor_along_last_dim,
- num_partitions=self.tp_size)
- q = splitter(q)[self.tp_rank]
- k = splitter(k)[self.tp_rank]
- return q, k
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
- ) -> torch.Tensor:
- qkv, _ = self.qkv_proj(hidden_states)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
- if self.q_norm is not None and self.k_norm is not None:
- q, k = self._apply_qk_norm(q, k)
- q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
- output, _ = self.o_proj(attn_output)
- return output
- class MolmoMLP(nn.Module):
- """Molmo's LLM mlp."""
- def __init__(
- self,
- config: PretrainedConfig,
- input_dim: Optional[int] = None,
- quant_config: Optional[QuantizationConfig] = None,
- ) -> None:
- super().__init__()
- self.hidden_size = config.hidden_size
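- # The config reports the fused (gate + up) width, so each branch of the
- # SwiGLU-style MLP uses half of config.intermediate_size.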
- self.intermediate_size = config.intermediate_size // 2
- # Feed-forward input projection.
- self.gate_up_proj = MergedColumnParallelLinear(
- input_dim or self.hidden_size,
- [self.intermediate_size] * 2,
- bias=False,
- quant_config=quant_config,
- )
- # Activation function.
- self.act_fn = SiluAndMul()
- # Feed-forward output projection.
- self.down_proj = RowParallelLinear(
- self.intermediate_size,
- self.hidden_size,
- bias=False,
- quant_config=quant_config,
- )
- def forward(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- gate_up, _ = self.gate_up_proj(x)
- x = self.act_fn(gate_up)
- x, _ = self.down_proj(x)
- return x
- class MolmoDecoderLayer(nn.Module):
- def __init__(
- self,
- config: PretrainedConfig,
- cache_config: Optional[CacheConfig] = None,
- quant_config: Optional[QuantizationConfig] = None,
- ) -> None:
- super().__init__()
- # Attention block.
- self.self_attn = MolmoAttention(config, cache_config, quant_config)
- # MLP block.
- self.mlp = MolmoMLP(config, quant_config=quant_config)
- # LayerNorm
- assert config.layer_norm_type == "rms"
- self.input_layernorm = RMSNorm(config.hidden_size,
- eps=config.layer_norm_eps)
- self.post_attention_layernorm = RMSNorm(config.hidden_size,
- eps=config.layer_norm_eps)
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
- residual: Optional[torch.Tensor],
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
- # Self Attention
- if residual is None:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- else:
- hidden_states, residual = self.input_layernorm(
- hidden_states, residual)
- hidden_states = self.self_attn(
- positions=positions,
- hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
- )
- hidden_states, residual = self.post_attention_layernorm(
- hidden_states, residual)
- hidden_states = self.mlp(hidden_states)
- return hidden_states, residual
- class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
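- # Variant used when config.norm_after is set: layer norms are applied to the
- # sub-layer outputs (post-norm) rather than their inputs, and no fused
- # residual is carried across layers (residual is returned as None).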
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
- residual: Optional[torch.Tensor],
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
- # Self Attention
- residual = hidden_states
- hidden_states = self.self_attn(
- positions=positions,
- hidden_states=hidden_states,
- kv_cache=kv_cache,
- attn_metadata=attn_metadata,
- )
- hidden_states = self.input_layernorm(hidden_states)
- hidden_states = hidden_states + residual
- residual = hidden_states
- hidden_states = self.mlp(hidden_states)
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = hidden_states + residual
- residual = None
- return hidden_states, residual
- class MolmoVisionBackbone(nn.Module):
- def __init__(
- self,
- config: PretrainedConfig,
- vision_config: VisionBackboneConfig,
- quant_config: Optional[QuantizationConfig] = None,
- ) -> None:
- super().__init__()
- self.vit_layers = VIT_LAYERS
- self.image_num_patch = vision_config.image_num_patch
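- # The ViT patch grid of each crop is pooled 2x2 before being handed to the
- # LLM; odd grid sizes are rounded up (the grid is padded in forward()).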
- self.llm_patches_per_crop = (
- (self.image_num_patch[0] + 1) // 2,
- (self.image_num_patch[1] + 1) // 2,
- )
- self.image_vit = VisionTransformer(vision_config,
- quant_config=quant_config)
- self.num_prefix_tokens = self.image_vit.num_prefix_tokens
- assert self.num_prefix_tokens in {
- 0, 1
- }, "Only 0 or 1 prefix tokens are supported"
- self.image_pooling_2d = MultiHeadDotProductAttention(
- vision_config,
- nlayers=len(self.vit_layers),
- quant_config=quant_config)
- self.image_projector = MolmoMLP(
- config,
- input_dim=vision_config.image_emb_dim,
- quant_config=quant_config,
- )
- image_dim = vision_config.image_emb_dim * len(self.vit_layers)
- self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))
- @property
- def dtype(self) -> torch.dtype:
- return self.image_vit.patch_embedding.weight.dtype
- @property
- def device(self) -> torch.device:
- return self.image_vit.patch_embedding.weight.device
- def encode_image(self, images: torch.Tensor) -> torch.Tensor:
- """
- :param images: (batch_size, num_crops, num_patch, n_pixels)
- """
- B, T, N, D = images.shape
- mask = ~torch.all(
- images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
- images = images.view(B * T, N, D)
- image_features = self.image_vit(images)
- if self.vit_layers is not None:
- features = []
- for layer in self.vit_layers:
- features.append(image_features[layer])
- image_features = torch.cat(features, dim=-1)
- else:
- image_features = image_features[-1]
- if self.num_prefix_tokens > 0:
- image_features = image_features[:, 1:]
- image_features = image_features * mask
- image_features = image_features.view(B, T, N, -1)
- return image_features
- def forward(
- self, images: torch.Tensor, image_masks: torch.Tensor
- ) -> torch.Tensor:
- # image_features: (batch_size, num_crops(=num_image), num_patch, len(vit_layers) * image_emb_dim) # noqa: E501
- batch_size, num_image = images.shape[:2]
- images = images.to(device=self.device, dtype=self.dtype)
- image_features = self.encode_image(images)
- og_dtype = image_features.dtype
- assert image_masks is not None
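- # image_masks indicates how much of each patch is real image content
- # (0 = fully padded, values in (0, 1) = partially padded); a learned pad
- # embedding is added to such patches before pooling.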
- pad_embed = self.pad_embed[:, None, None, None, :]
- all_pad = image_masks == 0
- partial_pad = torch.logical_and(
- image_masks < 1,
- torch.logical_not(all_pad)).to(dtype=torch.float32)
- all_pad = all_pad.to(dtype=torch.float32)
- image_features = image_features + pad_embed[0] * torch.unsqueeze(
- all_pad, -1)
- image_features = image_features + pad_embed[1] * torch.unsqueeze(
- partial_pad, -1)
- image_features = image_features.to(og_dtype)
- image_features = image_features.reshape(
- (batch_size, num_image) + self.image_num_patch + (-1, ), )
- if self.image_num_patch[0] % 2 == 1:
- # Pad so we can still pool 2x2 patches
- image_features = F.pad(
- image_features,
- (0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
- )
- # image pooling
- image_features = rearrange(
- image_features,
- 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
- dh=2,
- dw=2,
- )
- query = image_features.mean(-2, keepdim=True)
- image_features = self.image_pooling_2d(query, image_features)
- h, w = self.llm_patches_per_crop
- image_features = image_features.view(batch_size, num_image, h * w, -1)
- image_features = self.image_projector(image_features)
- # image_features: (batch_size, num_image, num_patch, d_model)
- return image_features
- class MolmoModel(nn.Module):
- def __init__(
- self,
- config: PretrainedConfig,
- cache_config: Optional[CacheConfig] = None,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.config = config
- self.embedding_size = config.embedding_size or config.vocab_size
- self.embedding_size += ADDITIONAL_VOCAB_SIZE
- self.embed_tokens = VocabParallelEmbedding(
- self.embedding_size,
- config.hidden_size,
- quant_config=quant_config,
- )
- decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \
- else MolmoDecoderLayer
- self.start_layer, self.end_layer, self.layers = make_layers(
- config.num_hidden_layers,
- lambda prefix: decoder_layer(config, cache_config, quant_config),
- prefix=f"{prefix}.layers",
- )
- assert config.layer_norm_type == "rms"
- self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
- intermediate_tensors: Optional[IntermediateTensors] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- ) -> Union[torch.Tensor, IntermediateTensors]:
- if get_pp_group().is_first_rank:
- if inputs_embeds is not None:
- hidden_states = inputs_embeds
- else:
- hidden_states = self.embed_tokens(input_ids)
- residual = None
- else:
- assert intermediate_tensors is not None
- hidden_states = intermediate_tensors["hidden_states"]
- residual = intermediate_tensors["residual"]
- # Apply blocks one-by-one.
- for i in range(self.start_layer, self.end_layer):
- layer = self.layers[i]
- hidden_states, residual = layer(
- positions,
- hidden_states,
- kv_caches[i - self.start_layer],
- attn_metadata,
- residual,
- )
- if not get_pp_group().is_last_rank:
- return IntermediateTensors({
- "hidden_states": hidden_states,
- "residual": residual
- })
- if residual is not None:
- hidden_states, _ = self.norm(hidden_states, residual)
- else:
- hidden_states = self.norm(hidden_states)
- return hidden_states
- cached_get_processor = lru_cache(get_processor)
- def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int,
- right_margin: int, pooling_size: int) -> int:
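- # Number of ViT patches along one axis contributed by `num_tiles` overlapping
- # crops: interior crops keep only their non-overlapping window, the outer
- # crops keep their margin, and each contribution is rounded up to a multiple
- # of pooling_size.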
- crop_window_patches = crop_patches - (left_margin + right_margin)
- if num_tiles > 1:
- left_crop_window_patches = (crop_window_patches + left_margin +
- pooling_size -
- 1) // pooling_size * pooling_size
- middle_crop_window_patches = (crop_window_patches + pooling_size -
- 1) // pooling_size * pooling_size
- right_crop_window_patches = (crop_window_patches + right_margin +
- pooling_size -
- 1) // pooling_size * pooling_size
- return left_crop_window_patches + (
- num_tiles -
- 2) * middle_crop_window_patches + right_crop_window_patches
- else:
- single_crop_window_patches = (crop_patches + pooling_size -
- 1) // pooling_size * pooling_size
- return single_crop_window_patches
- def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int,
- left_margin: int, right_margin: int, pooling_size: int) -> int:
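- # Total image tokens for a (tiling_h x tiling_w) layout: the pooled joint
- # image contributes one token per pooling window plus one separator token per
- # row and two start/end tokens; the resized global image is counted the same
- # way.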
- h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin,
- pooling_size)
- w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin,
- pooling_size)
- per_row = w // pooling_size + 1
- joint = per_row * (h // pooling_size) + 2
- image_token_length = (crop_patches + pooling_size - 1) // pooling_size
- resize = (image_token_length + 1) * image_token_length + 2
- return resize + joint
- def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int,
- right_margin: int, pooling_size: int) -> int:
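- # Worst case over all (rows, cols) tilings that use at most max_crops crops.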
- tilings = []
- for i in range(1, max_crops + 1):
- for j in range(1, max_crops + 1):
- if i * j <= max_crops:
- tilings.append((i, j))
- tokens = [
- get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin,
- right_margin, pooling_size) for i in range(len(tilings))
- ]
- return max(tokens)
- def get_max_molmo_image_tokens(ctx: InputContext) -> int:
- processor = cached_get_processor(ctx.model_config.model,
- trust_remote_code=True,
- revision=ctx.model_config.code_revision)
- image_processor = processor.image_processor
- max_llm_image_tokens = get_max_tokens(
- image_processor.max_crops,
- image_processor.base_image_input_size[0] //
- image_processor.image_patch_size,
- image_processor.overlap_margins[0],
- image_processor.overlap_margins[1],
- 2,
- )
- return max_llm_image_tokens
- # NOTE: preprocessing for the image data has been included in the
- # 'input_processor_for_molmo' function
- def image_input_mapper_for_molmo(
- ctx: InputContext,
- data: object,
- ):
- return MultiModalInputs(data)
- def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
- mm_counts: Mapping[str, int]):
- processor = cached_get_processor(ctx.model_config.model,
- trust_remote_code=True,
- revision=ctx.model_config.code_revision)
- image_processor = processor.image_processor
- base_image_input_d = image_processor.image_patch_size
- left_margin, right_margin = image_processor.overlap_margins
- max_crops = image_processor.max_crops
- # Assume: prompt_token_ids always starts with bos_token_id followed by image tokens # noqa: E501
- max_llm_image_tokens = get_max_molmo_image_tokens(ctx)
- if seq_len - max_llm_image_tokens - 1 < 0:
- raise RuntimeError(
- f"Molmo cannot process {max_crops} crops in a prompt, "
- "please increase max_model_len or reduce number of crops")
- # The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501
- tiling = (max_crops, 1)
- total_margin_pixels = base_image_input_d * (right_margin + left_margin)
- crop_patches = image_processor.base_image_input_size[
- 0] // base_image_input_d
- crop_window_patches = crop_patches - (right_margin + left_margin)
- crop_window_size = crop_window_patches * base_image_input_d
- h = crop_window_size * tiling[0] + total_margin_pixels
- w = crop_window_size * tiling[1] + total_margin_pixels
- dummy_image = Image.new("RGB", (w, h), color="red")
- out = processor.process("dummy prompt", dummy_image)
- token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
- out["input_ids"][:1 + max_llm_image_tokens])
- token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
- [0]) * (seq_len - max_llm_image_tokens - 1)
- dummy_seqdata = SequenceData(token_ids)
- dummy_imgdata = {
- "images": out["images"],
- "image_input_idx": out["image_input_idx"],
- }
- if "image_masks" in out:
- dummy_imgdata["image_masks"] = out["image_masks"]
- dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
- return dummy_seqdata, {"image": dummy_imgdata}
- def pad_images(
- max_total_crops: int,
- images: torch.Tensor,
- image_input_idx: torch.Tensor,
- image_masks: Optional[torch.Tensor] = None,
- ):
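- # Pad the crop dimension up to max_total_crops with -1 so every request
- # carries fixed-size image tensors; -1 marks padding for pixels, token
- # indices, and masks alike.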
- n = max_total_crops - images.shape[0]
- images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1)
- image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1)
- if image_masks is not None:
- image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1)
- return images, image_input_idx, image_masks
- def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
- prompt = llm_inputs["prompt"]
- multi_modal_data = llm_inputs.get("multi_modal_data")
- image = multi_modal_data.get("image") if multi_modal_data else None
- processor = cached_get_processor(ctx.model_config.model,
- trust_remote_code=True,
- revision=ctx.model_config.code_revision)
- # NOTE: message formatting for raw text prompt is only applied for
- # offline inference; for online inference, the prompt is always in
- # instruction format and tokenized.
- if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$",
- prompt):
- out = processor.process(prompt, image, message_format="none")
- elif prompt is not None:
- out = processor.process(prompt, image)
- else:
- out = processor.process(None,
- image,
- tokens=llm_inputs["prompt_token_ids"])
- image_processor = processor.image_processor
- max_total_crops = 1 + image_processor.max_crops
- if image is not None:
- images, image_input_idx, image_masks = pad_images(
- max_total_crops,
- out["images"],
- out["image_input_idx"],
- out.get("image_masks"),
- )
- else:
- base_image_input_size = image_processor.base_image_input_size
- image_patch_size = image_processor.image_patch_size
- image_num_patch = (
- base_image_input_size[0] // image_patch_size,
- base_image_input_size[1] // image_patch_size,
- )
- n_pixels = image_patch_size * image_patch_size * 3
- n_patches = image_num_patch[0] * image_num_patch[1]
- image_length_w = image_processor.image_token_length_w
- image_length_h = image_processor.image_token_length_h
- tokens_per_image = image_length_w * image_length_h
- images = torch.full(
- (max_total_crops, n_patches, n_pixels),
- -1,
- dtype=torch.float32,
- )
- image_input_idx = torch.full(
- (max_total_crops, tokens_per_image),
- -1,
- dtype=torch.int32,
- )
- image_masks = None
- if image_processor.image_padding_mask:
- image_masks = torch.full(
- (max_total_crops, n_patches),
- -1,
- dtype=torch.float32,
- )
- image_data = dict(
- images=images,
- image_input_idx=image_input_idx,
- )
- if image_masks is not None:
- image_data["image_masks"] = image_masks
- image_data["seq_len"] = torch.tensor(len(out["input_ids"]),
- dtype=torch.long)
- multi_modal_data = dict(image=image_data)
- return LLMInputs(
- prompt_token_ids=out["input_ids"],
- prompt=llm_inputs["prompt"],
- multi_modal_data=multi_modal_data,
- )
- @MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo)
- @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
- @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
- @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
- class MolmoForCausalLM(nn.Module, SupportsMultiModal):
- def __init__(
- self,
- config: PretrainedConfig,
- multimodal_config: Optional[MultiModalConfig] = None,
- cache_config: Optional[CacheConfig] = None,
- quant_config: Optional[Mapping[str, Any]] = None,
- ) -> None:
- super().__init__()
- self.config = config
- self.multimodal_config = multimodal_config
- vision_config = VisionBackboneConfig()
- self.vision_backbone = MolmoVisionBackbone(config, vision_config,
- quant_config)
- self.model = MolmoModel(config, cache_config, quant_config)
- if self.config.weight_tying:
- # Tie the output head to the input embedding table.
- self.lm_head = self.model.embed_tokens
- else:
- self.lm_head = ParallelLMHead(
- config.embedding_size or config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- )
- self.logits_processor = LogitsProcessor(config.embedding_size
- or config.vocab_size)
- self.sampler = Sampler()
- def _parse_and_validate_image_input(
- self,
- **kwargs: object,
- ) -> Optional[MolmoImageInputs]:
- images = kwargs.pop("images", None)
- image_masks = kwargs.pop("image_masks", None)
- if images is None:
- return None
- image_input_idx = kwargs.pop("image_input_idx", None)
- seq_len = kwargs.pop("seq_len", None)
- if image_input_idx is None:
- raise ValueError("image_input_idx is required for Molmo model.")
- if seq_len is None:
- raise ValueError("seq_len is required for Molmo model.")
- if not isinstance(seq_len, torch.Tensor):
- seq_len = torch.tensor(seq_len)
- return MolmoImageInputs(
- images=images,
- image_input_idx=image_input_idx,
- seq_len=seq_len,
- image_masks=image_masks,
- )
- def _process_image_input(
- self,
- image_input: MolmoImageInputs,
- ) -> torch.Tensor:
- image_features = self.vision_backbone(
- images=image_input["images"],
- image_masks=image_input["image_masks"],
- )
- return image_features
- def _merge_multimodal_embeddings(
- self,
- inputs_embeds: torch.Tensor,
- image_features: torch.Tensor,
- image_input_idx: torch.Tensor,
- seq_len: Union[torch.Tensor, List[torch.Tensor]],
- ) -> torch.Tensor:
- batch_size, num_image, num_patch = image_features.shape[:3]
- assert image_input_idx.shape == (batch_size, num_image, num_patch)
- image_features = image_features.to(inputs_embeds.device)
- seq_len = seq_len.to(inputs_embeds.device)
- # insert the image feature into the embedding.
- image_features = image_features.view(batch_size, num_image * num_patch,
- -1)
- image_input_idx = image_input_idx.view(batch_size,
- num_image * num_patch)
- valid = image_input_idx >= 0
- image_features = image_features * valid[:, :, None].to(
- image_features.dtype)
- image_features = image_features.view(
- batch_size * num_image * num_patch, -1).contiguous()
- image_input_idx = image_input_idx * valid.to(image_input_idx.dtype)
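- # Scatter the masked image features into the packed token sequence:
- # per-sequence cumulative offsets turn local image_input_idx positions into
- # global positions, and the one-hot matrix `mat` performs the scatter-add
- # via the einsum below.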
- offset = torch.cat(
- [seq_len.new_zeros(
- (1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None]
- image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
- image_input_idx = image_input_idx.flatten()[:, None]
- mat = image_input_idx == torch.arange(
- seq_len.sum().item(), device=inputs_embeds.device)[None, :]
- mat = mat.to(image_features.dtype)
- inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md',
- image_features, mat)
- return inputs_embeds
- def forward(
- self,
- input_ids: torch.LongTensor,
- positions: torch.LongTensor,
- kv_caches: List[torch.Tensor],
- attn_metadata: AttentionMetadata,
- **kwargs: object,
- ) -> Union[torch.Tensor, IntermediateTensors]:
- image_input = self._parse_and_validate_image_input(**kwargs)
- if image_input is not None:
- inputs_embeds = self.model.embed_tokens(input_ids)
- image_features = self._process_image_input(image_input)
- inputs_embeds = self._merge_multimodal_embeddings(
- inputs_embeds,
- image_features,
- image_input["image_input_idx"],
- image_input["seq_len"],
- )
- input_ids = None
- else:
- inputs_embeds = None
- hidden_states = self.model(
- input_ids=input_ids,
- positions=positions,
- kv_caches=kv_caches,
- attn_metadata=attn_metadata,
- inputs_embeds=inputs_embeds,
- )
- return hidden_states
- def compute_logits(self, hidden_states: torch.Tensor,
- sampling_metadata: SamplingMetadata) -> torch.Tensor:
- logits = self.logits_processor(self.lm_head, hidden_states,
- sampling_metadata)
- return logits
- def sample(
- self,
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> Optional[SamplerOutput]:
- next_tokens = self.sampler(logits, sampling_metadata)
- return next_tokens
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
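- # Map parameter names from the original Molmo checkpoint layout to the
- # module names used in this implementation.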
- params_mapping = [
- ("model.transformer.ln_f.weight", "model.norm.weight"),
- ("attn_out", "self_attn.o_proj"),
- ("att_proj", "self_attn.qkv_proj"),
- ("q_norm", "self_attn.q_norm"),
- ("k_norm", "self_attn.k_norm"),
- ("attn_norm", "input_layernorm"),
- ("ff_norm", "post_attention_layernorm"),
- ]
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- embedding_weight = dict()
- projector_weight = dict()
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
- if self.config.tie_word_embeddings and "lm_head.weight" in name:
- continue
- if "wte.embedding" in name:
- embedding_weight["embedding"] = loaded_weight
- continue
- if "wte.new_embedding" in name:
- embedding_weight["new_embedding"] = loaded_weight
- continue
- if "vision_backbone" in name:
- if name.startswith("model"):
- name = name[len("model."):]
- if 'image_projector' in name:
- if 'w1' in name:
- projector_weight['gate_proj'] = loaded_weight
- elif 'w3' in name:
- projector_weight['up_proj'] = loaded_weight
- elif 'w2' in name:
- projector_weight['down_proj'] = loaded_weight
- else:
- raise ValueError(
- f"Unexpected projector weight: {name}")
- continue
- else:
- if "transformer.blocks" in name:
- name = name.replace("transformer.blocks", "layers")
- if "ff_proj" in name:
- name = name.replace("ff_proj", "mlp.gate_up_proj")
- assert 'weight' in name
- up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
- loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
- elif "ff_out" in name:
- if "layers" in name:
- name = name.replace("ff_out", "mlp.down_proj")
- else:
- # lm head
- name = name.replace("model.transformer.ff_out",
- "lm_head")
- else:
- for (param_name, weight_name) in params_mapping:
- if param_name in name:
- name = name.replace(param_name, weight_name)
- break
- try:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- param = params_dict[name]
- except KeyError:
- raise ValueError(f"Unexpected weight: {name}") from None
- weight_loader = getattr(param, "weight_loader",
- default_weight_loader)
- weight_loader(param, loaded_weight)
- gate_up_proj_weight = torch.cat(
- [projector_weight["gate_proj"], projector_weight["up_proj"]],
- dim=0)
- name = "vision_backbone.image_projector.gate_up_proj.weight"
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, gate_up_proj_weight)
- down_proj_weight = projector_weight["down_proj"]
- name = "vision_backbone.image_projector.down_proj.weight"
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, down_proj_weight)
- embedding_weight = torch.cat(
- [embedding_weight["embedding"], embedding_weight["new_embedding"]],
- dim=0)
- name = "model.embed_tokens.weight"
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, embedding_weight)