|
@@ -0,0 +1,1294 @@
|
|
|
+import logging
|
|
|
+import math
|
|
|
+import re
|
|
|
+from array import array
|
|
|
+from dataclasses import dataclass
|
|
|
+from functools import lru_cache, partial
|
|
|
+from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict,
|
|
|
+ Union)
|
|
|
+
|
|
|
+import torch
|
|
|
+from einops import rearrange
|
|
|
+from PIL import Image
|
|
|
+from torch import nn
|
|
|
+from torch.nn import functional as F
|
|
|
+from transformers import PretrainedConfig
|
|
|
+
|
|
|
+import aphrodite.common.envs as envs
|
|
|
+from aphrodite.attention import Attention, AttentionMetadata
|
|
|
+from aphrodite.attention.selector import (_Backend, backend_name_to_enum,
|
|
|
+ get_global_forced_attn_backend)
|
|
|
+from aphrodite.common.config import CacheConfig, MultiModalConfig
|
|
|
+from aphrodite.common.logger import log_once
|
|
|
+from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
|
|
|
+ IntermediateTensors, SequenceData)
|
|
|
+from aphrodite.distributed import (get_pp_group,
|
|
|
+ get_tensor_model_parallel_rank,
|
|
|
+ get_tensor_model_parallel_world_size,
|
|
|
+ split_tensor_along_last_dim,
|
|
|
+ tensor_model_parallel_all_gather)
|
|
|
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
|
|
+from aphrodite.modeling.layers.activation import QuickGELU, SiluAndMul
|
|
|
+from aphrodite.modeling.layers.layernorm import RMSNorm
|
|
|
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
|
|
|
+ MergedColumnParallelLinear,
|
|
|
+ QKVParallelLinear,
|
|
|
+ RowParallelLinear)
|
|
|
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
|
|
|
+from aphrodite.modeling.layers.rotary_embedding import get_rope
|
|
|
+from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
|
|
|
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
|
|
|
+ ParallelLMHead, VocabParallelEmbedding)
|
|
|
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
|
|
|
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
|
|
|
+from aphrodite.modeling.models.utils import make_layers
|
|
|
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
|
|
|
+from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
|
|
|
+from aphrodite.platforms import current_platform
|
|
|
+from aphrodite.quantization.base_config import QuantizationConfig
|
|
|
+from aphrodite.transformers_utils.processor import get_processor
|
|
|
+
|
|
|
+log = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+VIT_LAYERS = [-2, -9]
|
|
|
+NUM_PREFIX_TOKENS = 1
|
|
|
+ADDITIONAL_VOCAB_SIZE = 128
|
|
|
+
|
|
|
+
|
|
|
+class MolmoImageInputs(TypedDict):
|
|
|
+ images: torch.Tensor
|
|
|
+ """Shape:
|
|
|
+ `(batch_size, num_crops, num_patch, patch_dim)`
|
|
|
+ """
|
|
|
+
|
|
|
+ image_input_idx: torch.Tensor
|
|
|
+ """Shape:
|
|
|
+ `(batch_size, num_crops, num_patch)`
|
|
|
+ """
|
|
|
+
|
|
|
+ seq_len: torch.Tensor
|
|
|
+ """Shape:
|
|
|
+ `(batch_size, )`
|
|
|
+ """
|
|
|
+
|
|
|
+ image_masks: Optional[torch.Tensor]
|
|
|
+ """Shape:
|
|
|
+ `(batch_size, num_crops, num_patch)`
|
|
|
+ """
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class VisionBackboneConfig:
|
|
|
+ image_default_input_size: Tuple[int, int] = (336, 336)
|
|
|
+ image_patch_size: int = 14
|
|
|
+ image_pos_patch_size: int = 14
|
|
|
+ image_emb_dim: int = 1024
|
|
|
+ image_num_heads: int = 16
|
|
|
+ image_num_key_value_heads: int = 16
|
|
|
+ image_num_layers: int = 23
|
|
|
+ image_mlp_dim: int = 4096
|
|
|
+ image_mlp_activations: str = "quick_gelu"
|
|
|
+ image_num_pos: int = 577
|
|
|
+ image_norm_eps: float = 1e-5
|
|
|
+
|
|
|
+ def __post_init__(self):
|
|
|
+ self.image_default_input_size = tuple(
|
|
|
+ self.image_default_input_size)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def image_num_patch(self):
|
|
|
+ h, w = self.image_default_input_size
|
|
|
+ return h // self.image_patch_size, w // self.image_patch_size
|
|
|
+
|
|
|
+
|
|
|
+class ViTMLP(nn.Module):
|
|
|
+ """MLP used in Vision Transformer."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: VisionBackboneConfig,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ self.w1 = ColumnParallelLinear(
|
|
|
+ config.image_emb_dim,
|
|
|
+ config.image_mlp_dim,
|
|
|
+ bias=True,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ assert config.image_mlp_activations == "quick_gelu"
|
|
|
+ self.act = QuickGELU()
|
|
|
+ self.w2 = RowParallelLinear(
|
|
|
+ config.image_mlp_dim,
|
|
|
+ config.image_emb_dim,
|
|
|
+ bias=True,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ x, _ = self.w1(x)
|
|
|
+ x = self.act(x)
|
|
|
+ x, _ = self.w2(x)
|
|
|
+ return x
|
|
|
+
|
|
|
+
|
|
|
+class MultiHeadDotProductAttention(nn.Module):
|
|
|
+ """Multi-head attention used in Vision Transformer."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: VisionBackboneConfig,
|
|
|
+ use_bias: bool = True,
|
|
|
+ nlayers: int = 1,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+
|
|
|
+ self.hidden_size = config.image_emb_dim
|
|
|
+ self.total_num_heads = config.image_num_heads
|
|
|
+ tp_size = get_tensor_model_parallel_world_size()
|
|
|
+
|
|
|
+ assert self.hidden_size % self.total_num_heads == 0
|
|
|
+ assert self.total_num_heads % tp_size == 0
|
|
|
+
|
|
|
+ self.num_heads = self.total_num_heads // tp_size
|
|
|
+ self.head_dim = self.hidden_size // self.total_num_heads
|
|
|
+
|
|
|
+ self.total_num_kv_heads = config.image_num_key_value_heads
|
|
|
+ if self.total_num_kv_heads >= tp_size:
|
|
|
+ assert self.total_num_kv_heads % tp_size == 0
|
|
|
+ else:
|
|
|
+ assert tp_size % self.total_num_kv_heads == 0
|
|
|
+
|
|
|
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
|
|
+
|
|
|
+ self.wq = ColumnParallelLinear(
|
|
|
+ nlayers * self.hidden_size,
|
|
|
+ self.total_num_heads * self.head_dim,
|
|
|
+ bias=use_bias,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+ self.wk = ColumnParallelLinear(
|
|
|
+ nlayers * self.hidden_size,
|
|
|
+ self.total_num_kv_heads * self.head_dim,
|
|
|
+ bias=use_bias,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+ self.wv = ColumnParallelLinear(
|
|
|
+ nlayers * self.hidden_size,
|
|
|
+ self.total_num_kv_heads * self.head_dim,
|
|
|
+ bias=use_bias,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+ self.wo = RowParallelLinear(
|
|
|
+ self.total_num_heads * self.head_dim,
|
|
|
+ self.hidden_size,
|
|
|
+ bias=use_bias,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+ selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
|
|
|
+ if selected_backend is None:
|
|
|
+ backend_by_env_var: Optional[str] = envs.APHRODITE_ATTENTION_BACKEND
|
|
|
+ if backend_by_env_var is not None:
|
|
|
+ selected_backend = backend_name_to_enum(backend_by_env_var)
|
|
|
+ if selected_backend is None:
|
|
|
+
|
|
|
+ device_available = current_platform.get_device_capability()[0] >= 8
|
|
|
+ if device_available:
|
|
|
+ from transformers.utils import is_flash_attn_2_available
|
|
|
+ if is_flash_attn_2_available():
|
|
|
+ self._use_flash_attn = True
|
|
|
+ else:
|
|
|
+ log_once(
|
|
|
+ level="WARNING",
|
|
|
+ message=
|
|
|
+ "Current Molmo implementation has a bug with "
|
|
|
+ "`aphrodite-flash-attn` inside vision module, so we use"
|
|
|
+ " xformers backend instead. You can run `pip install "
|
|
|
+ "flash-attn to use flash-attention backend."
|
|
|
+ )
|
|
|
+ self._use_flash_attn = False
|
|
|
+ else:
|
|
|
+ self._use_flash_attn = False
|
|
|
+ else:
|
|
|
+ if selected_backend == _Backend.FLASH_ATTN:
|
|
|
+ self._use_flash_attn = True
|
|
|
+ elif selected_backend == _Backend.XFORMERS:
|
|
|
+ self._use_flash_attn = False
|
|
|
+ else:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Molmo does not support {selected_backend} backend now.")
|
|
|
+
|
|
|
+ def forward(self,
|
|
|
+ inputs_q: torch.Tensor,
|
|
|
+ inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
+
|
|
|
+ if inputs_kv is not None:
|
|
|
+ inputs_k = inputs_kv
|
|
|
+ inputs_v = inputs_kv
|
|
|
+ else:
|
|
|
+ inputs_k = inputs_q
|
|
|
+ inputs_v = inputs_q
|
|
|
+
|
|
|
+ xq, _ = self.wq(inputs_q)
|
|
|
+ xk, _ = self.wk(inputs_k)
|
|
|
+ xv, _ = self.wv(inputs_v)
|
|
|
+ q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
|
|
|
+ kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
|
|
|
+ xq = xq.view(*q_shape)
|
|
|
+ xk = xk.view(*kv_shape)
|
|
|
+ xv = xv.view(*kv_shape)
|
|
|
+
|
|
|
+ if self._use_flash_attn:
|
|
|
+ from flash_attn import flash_attn_func
|
|
|
+ output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
|
|
|
+ else:
|
|
|
+ from xformers import ops as xops
|
|
|
+ output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
|
|
|
+
|
|
|
+ output = rearrange(output, "b s h d -> b s (h d)").contiguous()
|
|
|
+ output, _ = self.wo(output)
|
|
|
+
|
|
|
+ return output
|
|
|
+
|
|
|
+
|
|
|
+class ResidualAttentionBlock(nn.Module):
|
|
|
+ """Residual attention block used in Vision Transformer."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: VisionBackboneConfig,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ self.attention = MultiHeadDotProductAttention(
|
|
|
+ config, quant_config=quant_config)
|
|
|
+ self.feed_forward = ViTMLP(config, quant_config)
|
|
|
+ self.attention_norm = nn.LayerNorm(
|
|
|
+ config.image_emb_dim,
|
|
|
+ eps=config.image_norm_eps,
|
|
|
+ )
|
|
|
+ self.ffn_norm = nn.LayerNorm(
|
|
|
+ config.image_emb_dim,
|
|
|
+ eps=config.image_norm_eps,
|
|
|
+ )
|
|
|
+
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ x = x + self.attention(self.attention_norm(x))
|
|
|
+ x = x + self.feed_forward(self.ffn_norm(x))
|
|
|
+ return x
|
|
|
+
|
|
|
+
|
|
|
+class BlockCollection(nn.Module):
|
|
|
+ """Collection of residual attention blocks used in Vision Transformer."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: VisionBackboneConfig,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ self.resblocks = nn.ModuleList([
|
|
|
+ ResidualAttentionBlock(config, quant_config)
|
|
|
+ for _ in range(config.image_num_layers)
|
|
|
+ ])
|
|
|
+
|
|
|
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
|
|
+ hidden_states = []
|
|
|
+ for r in self.resblocks:
|
|
|
+ x = r(x)
|
|
|
+ hidden_states.append(x)
|
|
|
+ return hidden_states
|
|
|
+
|
|
|
+
|
|
|
+def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
|
|
|
+ return token.view(1, 1, -1).expand(batch_size, -1, -1)
|
|
|
+
|
|
|
+
|
|
|
+class VisionTransformer(nn.Module):
|
|
|
+ """Vision Transformer used in Vision Backbone."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: VisionBackboneConfig,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ):
|
|
|
+ super().__init__()
|
|
|
+ scale = config.image_emb_dim**-0.5
|
|
|
+ self.patch_num = config.image_num_patch
|
|
|
+ self.class_embedding = nn.Parameter(
|
|
|
+ torch.randn(config.image_emb_dim) * scale)
|
|
|
+ self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
|
|
|
+ self.positional_embedding = nn.Parameter(
|
|
|
+ torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
|
|
|
+ image_patch_size = config.image_patch_size
|
|
|
+ self.patch_embedding = nn.Linear(
|
|
|
+ image_patch_size * image_patch_size * 3,
|
|
|
+ config.image_emb_dim,
|
|
|
+ bias=False,
|
|
|
+ )
|
|
|
+ self.pre_ln = nn.LayerNorm(config.image_emb_dim,
|
|
|
+ eps=config.image_norm_eps)
|
|
|
+ self.transformer = BlockCollection(config, quant_config)
|
|
|
+
|
|
|
+ def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
|
|
|
+ cls_emb = self.positional_embedding[0:1]
|
|
|
+ pos_emb = self.positional_embedding[1:]
|
|
|
+
|
|
|
+ pos_emb = pos_emb.reshape(
|
|
|
+ (int(math.sqrt(pos_emb.shape[0])),
|
|
|
+ int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
|
|
|
+
|
|
|
+ (patch_num_0, patch_num_1) = patch_num
|
|
|
+
|
|
|
+ if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
|
|
|
+
|
|
|
+ pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
|
|
|
+ pos_emb = F.interpolate(
|
|
|
+ pos_emb,
|
|
|
+ size=(patch_num_0, patch_num_1),
|
|
|
+ mode="bicubic",
|
|
|
+ align_corners=False,
|
|
|
+ antialias=True,
|
|
|
+ )
|
|
|
+ pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
|
|
|
+
|
|
|
+ pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
|
|
|
+ x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
|
|
|
+ dim=1).to(x.dtype)
|
|
|
+ return x
|
|
|
+
|
|
|
+ def forward(self,
|
|
|
+ x: torch.Tensor,
|
|
|
+ patch_num: int = None) -> List[torch.Tensor]:
|
|
|
+ """
|
|
|
+ : param x: (batch_size, num_patch, n_pixels)
|
|
|
+ """
|
|
|
+ if patch_num is None:
|
|
|
+ patch_num = self.patch_num
|
|
|
+ B, N, D = x.shape
|
|
|
+
|
|
|
+ x = self.patch_embedding(x)
|
|
|
+
|
|
|
+
|
|
|
+ x = torch.cat(
|
|
|
+ [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
|
|
|
+ dim=1)
|
|
|
+ x = self.add_pos_emb(x, patch_num)
|
|
|
+
|
|
|
+ x = self.pre_ln(x)
|
|
|
+
|
|
|
+ hidden_states = self.transformer(x)
|
|
|
+ return hidden_states
|
|
|
+
|
|
|
+
|
|
|
+class MolmoAttention(nn.Module):
|
|
|
+ """Molmo's LLM attention."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: PretrainedConfig,
|
|
|
+ cache_config: Optional[CacheConfig] = None,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.hidden_size = config.hidden_size
|
|
|
+ self.tp_size = get_tensor_model_parallel_world_size()
|
|
|
+ self.total_num_heads = config.num_attention_heads
|
|
|
+
|
|
|
+ assert self.hidden_size % self.total_num_heads == 0
|
|
|
+ assert self.total_num_heads % self.tp_size == 0
|
|
|
+
|
|
|
+ self.num_heads = self.total_num_heads // self.tp_size
|
|
|
+ self.total_num_kv_heads = config.num_key_value_heads \
|
|
|
+ or self.total_num_heads
|
|
|
+ if self.total_num_kv_heads >= self.tp_size:
|
|
|
+ assert self.total_num_kv_heads % self.tp_size == 0
|
|
|
+ else:
|
|
|
+ assert self.tp_size % self.total_num_kv_heads == 0
|
|
|
+
|
|
|
+ self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
|
|
|
+ self.head_dim = self.hidden_size // self.total_num_heads
|
|
|
+ self.q_size = self.num_heads * self.head_dim
|
|
|
+ self.kv_size = self.num_kv_heads * self.head_dim
|
|
|
+ self.max_position_embeddings = config.max_position_embeddings
|
|
|
+ self.rope_theta = config.rope_theta
|
|
|
+
|
|
|
+
|
|
|
+ self.qkv_proj = QKVParallelLinear(
|
|
|
+ self.hidden_size,
|
|
|
+ self.head_dim,
|
|
|
+ self.total_num_heads,
|
|
|
+ self.total_num_kv_heads,
|
|
|
+ bias=config.qkv_bias,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.tp_rank: Optional[int] = None
|
|
|
+ self.k_norm: Optional[nn.Module] = None
|
|
|
+ self.q_norm: Optional[nn.Module] = None
|
|
|
+ if config.attention_layer_norm:
|
|
|
+ self.tp_rank = get_tensor_model_parallel_rank()
|
|
|
+ self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
|
|
|
+ eps=config.layer_norm_eps)
|
|
|
+ self.q_norm = RMSNorm(config.hidden_size,
|
|
|
+ eps=config.layer_norm_eps)
|
|
|
+
|
|
|
+
|
|
|
+ self.rotary_emb = get_rope(
|
|
|
+ self.head_dim,
|
|
|
+ rotary_dim=self.head_dim,
|
|
|
+ max_position=self.max_position_embeddings,
|
|
|
+ base=self.rope_theta,
|
|
|
+ )
|
|
|
+ self.scaling = self.head_dim**-0.5
|
|
|
+ self.attn = Attention(self.num_heads,
|
|
|
+ self.head_dim,
|
|
|
+ self.scaling,
|
|
|
+ num_kv_heads=self.num_kv_heads,
|
|
|
+ cache_config=cache_config,
|
|
|
+ quant_config=quant_config)
|
|
|
+
|
|
|
+
|
|
|
+ self.o_proj = RowParallelLinear(
|
|
|
+ self.total_num_heads * self.head_dim,
|
|
|
+ self.hidden_size,
|
|
|
+ bias=False,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ def _apply_qk_norm(self, q: torch.Tensor,
|
|
|
+ k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
+ if self.tp_size > 1:
|
|
|
+ q = tensor_model_parallel_all_gather(q.contiguous())
|
|
|
+ k = tensor_model_parallel_all_gather(k.contiguous())
|
|
|
+ q = self.q_norm.forward_native(q)
|
|
|
+ k = self.k_norm.forward_native(k)
|
|
|
+ if self.tp_size > 1:
|
|
|
+ splitter = partial(split_tensor_along_last_dim,
|
|
|
+ num_partitions=self.tp_size)
|
|
|
+ q = splitter(q)[self.tp_rank]
|
|
|
+ k = splitter(k)[self.tp_rank]
|
|
|
+ return q, k
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ positions: torch.Tensor,
|
|
|
+ hidden_states: torch.Tensor,
|
|
|
+ kv_cache: torch.Tensor,
|
|
|
+ attn_metadata: AttentionMetadata,
|
|
|
+ ) -> torch.Tensor:
|
|
|
+ qkv, _ = self.qkv_proj(hidden_states)
|
|
|
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
|
|
+ if self.q_norm is not None and self.k_norm is not None:
|
|
|
+ q, k = self._apply_qk_norm(q, k)
|
|
|
+ q, k = self.rotary_emb(positions, q, k)
|
|
|
+ attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
|
|
+ output, _ = self.o_proj(attn_output)
|
|
|
+ return output
|
|
|
+
|
|
|
+
|
|
|
+class MolmoMLP(nn.Module):
|
|
|
+ """Molmo's LLM mlp."""
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: PretrainedConfig,
|
|
|
+ input_dim: Optional[int] = None,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.hidden_size = config.hidden_size
|
|
|
+ self.intermediate_size = config.intermediate_size // 2
|
|
|
+
|
|
|
+
|
|
|
+ self.gate_up_proj = MergedColumnParallelLinear(
|
|
|
+ input_dim or self.hidden_size,
|
|
|
+ [self.intermediate_size] * 2,
|
|
|
+ bias=False,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+ self.act_fn = SiluAndMul()
|
|
|
+
|
|
|
+
|
|
|
+ self.down_proj = RowParallelLinear(
|
|
|
+ self.intermediate_size,
|
|
|
+ self.hidden_size,
|
|
|
+ bias=False,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ x: torch.Tensor,
|
|
|
+ ) -> torch.Tensor:
|
|
|
+ gate_up, _ = self.gate_up_proj(x)
|
|
|
+ x = self.act_fn(gate_up)
|
|
|
+ x, _ = self.down_proj(x)
|
|
|
+ return x
|
|
|
+
|
|
|
+
|
|
|
+class MolmoDecoderLayer(nn.Module):
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: PretrainedConfig,
|
|
|
+ cache_config: Optional[CacheConfig] = None,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ) -> None:
|
|
|
+ super().__init__()
|
|
|
+
|
|
|
+ self.self_attn = MolmoAttention(config, cache_config, quant_config)
|
|
|
+
|
|
|
+
|
|
|
+ self.mlp = MolmoMLP(config, quant_config=quant_config)
|
|
|
+
|
|
|
+
|
|
|
+ assert config.layer_norm_type == "rms"
|
|
|
+ self.input_layernorm = RMSNorm(config.hidden_size,
|
|
|
+ eps=config.layer_norm_eps)
|
|
|
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
|
|
+ eps=config.layer_norm_eps)
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ positions: torch.Tensor,
|
|
|
+ hidden_states: torch.Tensor,
|
|
|
+ kv_cache: torch.Tensor,
|
|
|
+ attn_metadata: AttentionMetadata,
|
|
|
+ residual: Optional[torch.Tensor],
|
|
|
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
|
|
+
|
|
|
+ if residual is None:
|
|
|
+ residual = hidden_states
|
|
|
+ hidden_states = self.input_layernorm(hidden_states)
|
|
|
+ else:
|
|
|
+ hidden_states, residual = self.input_layernorm(
|
|
|
+ hidden_states, residual)
|
|
|
+ hidden_states = self.self_attn(
|
|
|
+ positions=positions,
|
|
|
+ hidden_states=hidden_states,
|
|
|
+ kv_cache=kv_cache,
|
|
|
+ attn_metadata=attn_metadata,
|
|
|
+ )
|
|
|
+
|
|
|
+ hidden_states, residual = self.post_attention_layernorm(
|
|
|
+ hidden_states, residual)
|
|
|
+ hidden_states = self.mlp(hidden_states)
|
|
|
+ return hidden_states, residual
|
|
|
+
|
|
|
+
|
|
|
+class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ positions: torch.Tensor,
|
|
|
+ hidden_states: torch.Tensor,
|
|
|
+ kv_cache: torch.Tensor,
|
|
|
+ attn_metadata: AttentionMetadata,
|
|
|
+ residual: Optional[torch.Tensor],
|
|
|
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
|
|
+
|
|
|
+ residual = hidden_states
|
|
|
+ hidden_states = self.self_attn(
|
|
|
+ positions=positions,
|
|
|
+ hidden_states=hidden_states,
|
|
|
+ kv_cache=kv_cache,
|
|
|
+ attn_metadata=attn_metadata,
|
|
|
+ )
|
|
|
+
|
|
|
+ hidden_states = self.input_layernorm(hidden_states)
|
|
|
+ hidden_states = hidden_states + residual
|
|
|
+ residual = hidden_states
|
|
|
+
|
|
|
+ hidden_states = self.mlp(hidden_states)
|
|
|
+ hidden_states = self.post_attention_layernorm(hidden_states)
|
|
|
+ hidden_states = hidden_states + residual
|
|
|
+ residual = None
|
|
|
+ return hidden_states, residual
|
|
|
+
|
|
|
+
|
|
|
+class MolmoVisionBackbone(nn.Module):
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: PretrainedConfig,
|
|
|
+ vision_config: VisionBackboneConfig,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ ) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.vit_layers = VIT_LAYERS
|
|
|
+ self.image_num_patch = vision_config.image_num_patch
|
|
|
+ self.llm_patches_per_crop = (
|
|
|
+ (self.image_num_patch[0] + 1) // 2,
|
|
|
+ (self.image_num_patch[1] + 1) // 2,
|
|
|
+ )
|
|
|
+ self.image_vit = VisionTransformer(vision_config,
|
|
|
+ quant_config=quant_config)
|
|
|
+ self.num_prefix_tokens = self.image_vit.num_prefix_tokens
|
|
|
+ assert self.num_prefix_tokens in {
|
|
|
+ 0, 1
|
|
|
+ }, "Only 0 or 1 prefix tokens are supported"
|
|
|
+ self.image_pooling_2d = MultiHeadDotProductAttention(
|
|
|
+ vision_config,
|
|
|
+ nlayers=len(self.vit_layers),
|
|
|
+ quant_config=quant_config)
|
|
|
+ self.image_projector = MolmoMLP(
|
|
|
+ config,
|
|
|
+ input_dim=vision_config.image_emb_dim,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ image_dim = vision_config.image_emb_dim * len(self.vit_layers)
|
|
|
+ self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))
|
|
|
+
|
|
|
+ @property
|
|
|
+ def dtype(self) -> torch.dtype:
|
|
|
+ return self.image_vit.patch_embedding.weight.dtype
|
|
|
+
|
|
|
+ @property
|
|
|
+ def device(self) -> torch.device:
|
|
|
+ return self.image_vit.patch_embedding.weight.device
|
|
|
+
|
|
|
+ def encode_image(self, images: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ : param images: (batch_size, num_crops, num_patch, n_pixels)
|
|
|
+ """
|
|
|
+ B, T, N, D = images.shape
|
|
|
+
|
|
|
+ mask = ~torch.all(
|
|
|
+ images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
|
|
|
+
|
|
|
+ images = images.view(B * T, N, D)
|
|
|
+ image_features = self.image_vit(images)
|
|
|
+
|
|
|
+ if self.vit_layers is not None:
|
|
|
+ features = []
|
|
|
+ for layer in self.vit_layers:
|
|
|
+ features.append(image_features[layer])
|
|
|
+ image_features = torch.cat(features, dim=-1)
|
|
|
+ else:
|
|
|
+ image_features = image_features[-1]
|
|
|
+
|
|
|
+ if self.num_prefix_tokens > 0:
|
|
|
+ image_features = image_features[:, 1:]
|
|
|
+
|
|
|
+ image_features = image_features * mask
|
|
|
+ image_features = image_features.view(B, T, N, -1)
|
|
|
+
|
|
|
+ return image_features
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self, images: torch.Tensor, image_masks: torch.Tensor
|
|
|
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
|
|
+
|
|
|
+
|
|
|
+ batch_size, num_image = images.shape[:2]
|
|
|
+ images = images.to(device=self.device, dtype=self.dtype)
|
|
|
+ image_features = self.encode_image(images)
|
|
|
+
|
|
|
+ og_dtype = image_features.dtype
|
|
|
+ assert image_masks is not None
|
|
|
+ pad_embed = self.pad_embed[:, None, None, None, :]
|
|
|
+ all_pad = image_masks == 0
|
|
|
+ partial_pad = torch.logical_and(
|
|
|
+ image_masks < 1,
|
|
|
+ torch.logical_not(all_pad)).to(dtype=torch.float32)
|
|
|
+ all_pad = all_pad.to(dtype=torch.float32)
|
|
|
+ image_features = image_features + pad_embed[0] * torch.unsqueeze(
|
|
|
+ all_pad, -1)
|
|
|
+ image_features = image_features + pad_embed[1] * torch.unsqueeze(
|
|
|
+ partial_pad, -1)
|
|
|
+
|
|
|
+ image_features = image_features.to(og_dtype)
|
|
|
+
|
|
|
+ image_features = image_features.reshape(
|
|
|
+ (batch_size, num_image) + self.image_num_patch + (-1, ), )
|
|
|
+
|
|
|
+ if self.image_num_patch[0] % 2 == 1:
|
|
|
+
|
|
|
+ image_features = F.pad(
|
|
|
+ image_features,
|
|
|
+ (0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+ image_features = rearrange(
|
|
|
+ image_features,
|
|
|
+ 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
|
|
|
+ dh=2,
|
|
|
+ dw=2,
|
|
|
+ )
|
|
|
+
|
|
|
+ query = image_features.mean(-2, keepdim=True)
|
|
|
+ image_features = self.image_pooling_2d(query, image_features)
|
|
|
+
|
|
|
+ h, w = self.llm_patches_per_crop
|
|
|
+ image_features = image_features.view(batch_size, num_image, h * w, -1)
|
|
|
+
|
|
|
+ image_features = self.image_projector(image_features)
|
|
|
+
|
|
|
+
|
|
|
+ return image_features
|
|
|
+
|
|
|
+
|
|
|
+class MolmoModel(nn.Module):
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: PretrainedConfig,
|
|
|
+ cache_config: Optional[CacheConfig] = None,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ prefix: str = "",
|
|
|
+ ) -> None:
|
|
|
+ super().__init__()
|
|
|
+ self.config = config
|
|
|
+
|
|
|
+ self.embedding_size = config.embedding_size or config.vocab_size
|
|
|
+ self.embedding_size += ADDITIONAL_VOCAB_SIZE
|
|
|
+ self.embed_tokens = VocabParallelEmbedding(
|
|
|
+ self.embedding_size,
|
|
|
+ config.hidden_size,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \
|
|
|
+ else MolmoDecoderLayer
|
|
|
+ self.start_layer, self.end_layer, self.layers = make_layers(
|
|
|
+ config.num_hidden_layers,
|
|
|
+ lambda prefix: decoder_layer(config, cache_config, quant_config),
|
|
|
+ prefix=f"{prefix}.layers",
|
|
|
+ )
|
|
|
+
|
|
|
+ assert config.layer_norm_type == "rms"
|
|
|
+ self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ input_ids: torch.Tensor,
|
|
|
+ positions: torch.Tensor,
|
|
|
+ kv_caches: List[torch.Tensor],
|
|
|
+ attn_metadata: AttentionMetadata,
|
|
|
+ intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
|
+ inputs_embeds: Optional[torch.Tensor] = None,
|
|
|
+ ) -> torch.Tensor:
|
|
|
+ if get_pp_group().is_first_rank:
|
|
|
+ if inputs_embeds is not None:
|
|
|
+ hidden_states = inputs_embeds
|
|
|
+ else:
|
|
|
+ hidden_states = self.embed_tokens(input_ids)
|
|
|
+ residual = None
|
|
|
+ else:
|
|
|
+ assert intermediate_tensors is not None
|
|
|
+ hidden_states = intermediate_tensors["hidden_states"]
|
|
|
+ residual = intermediate_tensors["residual"]
|
|
|
+
|
|
|
+
|
|
|
+ for i in range(self.start_layer, self.end_layer):
|
|
|
+ layer = self.layers[i]
|
|
|
+ hidden_states, residual = layer(
|
|
|
+ positions,
|
|
|
+ hidden_states,
|
|
|
+ kv_caches[i - self.start_layer],
|
|
|
+ attn_metadata,
|
|
|
+ residual,
|
|
|
+ )
|
|
|
+ if not get_pp_group().is_last_rank:
|
|
|
+ return IntermediateTensors({
|
|
|
+ "hidden_states": hidden_states,
|
|
|
+ "residual": residual
|
|
|
+ })
|
|
|
+ if residual is not None:
|
|
|
+ hidden_states, _ = self.norm(hidden_states, residual)
|
|
|
+ else:
|
|
|
+ hidden_states = self.norm(hidden_states)
|
|
|
+ return hidden_states
|
|
|
+
|
|
|
+
|
|
|
+cached_get_processor = lru_cache(get_processor)
|
|
|
+
|
|
|
+
|
|
|
+def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int,
|
|
|
+ right_margin: int, pooling_size: int) -> int:
|
|
|
+ crop_window_patches = crop_patches - (left_margin + right_margin)
|
|
|
+ if num_tiles > 1:
|
|
|
+ left_crop_window_patches = (crop_window_patches + left_margin +
|
|
|
+ pooling_size -
|
|
|
+ 1) // pooling_size * pooling_size
|
|
|
+ middle_crop_window_patches = (crop_window_patches + pooling_size -
|
|
|
+ 1) // pooling_size * pooling_size
|
|
|
+ right_crop_window_patches = (crop_window_patches + right_margin +
|
|
|
+ pooling_size -
|
|
|
+ 1) // pooling_size * pooling_size
|
|
|
+ return left_crop_window_patches + (
|
|
|
+ num_tiles -
|
|
|
+ 2) * middle_crop_window_patches + right_crop_window_patches
|
|
|
+ else:
|
|
|
+ single_crop_window_patches = (crop_patches + pooling_size -
|
|
|
+ 1) // pooling_size * pooling_size
|
|
|
+ return single_crop_window_patches
|
|
|
+
|
|
|
+
|
|
|
+def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int,
|
|
|
+ left_margin: int, right_margin: int, pooling_size: int) -> int:
|
|
|
+ h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin,
|
|
|
+ pooling_size)
|
|
|
+ w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin,
|
|
|
+ pooling_size)
|
|
|
+ per_row = w // pooling_size + 1
|
|
|
+ joint = per_row * (h // pooling_size) + 2
|
|
|
+ image_token_length = (crop_patches + pooling_size - 1) // pooling_size
|
|
|
+ resize = (image_token_length + 1) * image_token_length + 2
|
|
|
+ return resize + joint
|
|
|
+
|
|
|
+
|
|
|
+def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int,
|
|
|
+ right_margin: int, pooling_size: int) -> int:
|
|
|
+ tilings = []
|
|
|
+ for i in range(1, max_crops + 1):
|
|
|
+ for j in range(1, max_crops + 1):
|
|
|
+ if i * j <= max_crops:
|
|
|
+ tilings.append((i, j))
|
|
|
+ tokens = [
|
|
|
+ get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin,
|
|
|
+ right_margin, pooling_size) for i in range(len(tilings))
|
|
|
+ ]
|
|
|
+ return max(tokens)
|
|
|
+
|
|
|
+
|
|
|
+def get_max_molmo_image_tokens(ctx: InputContext) -> int:
|
|
|
+ processor = cached_get_processor(ctx.model_config.model,
|
|
|
+ trust_remote_code=True,
|
|
|
+ revision=ctx.model_config.code_revision)
|
|
|
+ image_processor = processor.image_processor
|
|
|
+ max_llm_image_tokens = get_max_tokens(
|
|
|
+ image_processor.max_crops,
|
|
|
+ image_processor.base_image_input_size[0] //
|
|
|
+ image_processor.image_patch_size,
|
|
|
+ image_processor.overlap_margins[0],
|
|
|
+ image_processor.overlap_margins[1],
|
|
|
+ 2,
|
|
|
+ )
|
|
|
+ return max_llm_image_tokens
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+def image_input_mapper_for_molmo(
|
|
|
+ ctx: InputContext,
|
|
|
+ data: object,
|
|
|
+):
|
|
|
+ return MultiModalInputs(data)
|
|
|
+
|
|
|
+
|
|
|
+def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
|
|
|
+ mm_counts: Mapping[str, int]):
|
|
|
+ processor = cached_get_processor(ctx.model_config.model,
|
|
|
+ trust_remote_code=True,
|
|
|
+ revision=ctx.model_config.code_revision)
|
|
|
+ image_processor = processor.image_processor
|
|
|
+
|
|
|
+ base_image_input_d = image_processor.image_patch_size
|
|
|
+ left_margin, right_margin = image_processor.overlap_margins
|
|
|
+ max_crops = image_processor.max_crops
|
|
|
+
|
|
|
+
|
|
|
+ max_llm_image_tokens = get_max_molmo_image_tokens(ctx)
|
|
|
+ if seq_len - max_llm_image_tokens - 1 < 0:
|
|
|
+ raise RuntimeError(
|
|
|
+ f"Molmo cannot process {max_crops} crops in a prompt, "
|
|
|
+ "please increase max_model_len or reduce number of crops")
|
|
|
+
|
|
|
+
|
|
|
+ tiling = (max_crops, 1)
|
|
|
+ total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
|
|
+ crop_patches = image_processor.base_image_input_size[
|
|
|
+ 0] // base_image_input_d
|
|
|
+ crop_window_patches = crop_patches - (right_margin + left_margin)
|
|
|
+ crop_window_size = crop_window_patches * base_image_input_d
|
|
|
+
|
|
|
+ h = crop_window_size * tiling[0] + total_margin_pixels
|
|
|
+ w = crop_window_size * tiling[1] + total_margin_pixels
|
|
|
+
|
|
|
+ dummy_image = Image.new("RGB", (w, h), color="red")
|
|
|
+
|
|
|
+ out = processor.process("dummy prompt", dummy_image)
|
|
|
+
|
|
|
+ token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
|
|
|
+ out["input_ids"][:1 + max_llm_image_tokens])
|
|
|
+ token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
|
|
|
+ [0]) * (seq_len - max_llm_image_tokens - 1)
|
|
|
+ dummy_seqdata = SequenceData(token_ids)
|
|
|
+ dummy_imgdata = {
|
|
|
+ "images": out["images"],
|
|
|
+ "image_input_idx": out["image_input_idx"],
|
|
|
+ }
|
|
|
+ if "image_masks" in out:
|
|
|
+ dummy_imgdata["image_masks"] = out["image_masks"]
|
|
|
+ dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
|
|
|
+ return dummy_seqdata, {"image": dummy_imgdata}
|
|
|
+
|
|
|
+
|
|
|
+def pad_images(
|
|
|
+ max_total_crops: int,
|
|
|
+ images: torch.Tensor,
|
|
|
+ image_input_idx: torch.Tensor,
|
|
|
+ image_masks: Optional[torch.Tensor] = None,
|
|
|
+):
|
|
|
+ n = max_total_crops - images.shape[0]
|
|
|
+ images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1)
|
|
|
+ image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1)
|
|
|
+ if image_masks is not None:
|
|
|
+ image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1)
|
|
|
+ return images, image_input_idx, image_masks
|
|
|
+
|
|
|
+
|
|
|
+def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
|
|
|
+ prompt = llm_inputs["prompt"]
|
|
|
+ multi_modal_data = llm_inputs.get("multi_modal_data")
|
|
|
+ image = multi_modal_data.get("image")
|
|
|
+ processor = cached_get_processor(ctx.model_config.model,
|
|
|
+ trust_remote_code=True,
|
|
|
+ revision=ctx.model_config.code_revision)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$",
|
|
|
+ prompt):
|
|
|
+ out = processor.process(prompt, image, message_format="none")
|
|
|
+ elif prompt is not None:
|
|
|
+ out = processor.process(prompt, image)
|
|
|
+ else:
|
|
|
+ out = processor.process(None,
|
|
|
+ image,
|
|
|
+ tokens=llm_inputs["prompt_token_ids"])
|
|
|
+
|
|
|
+ image_processor = processor.image_processor
|
|
|
+ max_total_crops = 1 + image_processor.max_crops
|
|
|
+ if image is not None:
|
|
|
+ images, image_input_idx, image_masks = pad_images(
|
|
|
+ max_total_crops,
|
|
|
+ out["images"],
|
|
|
+ out["image_input_idx"],
|
|
|
+ out.get("image_masks"),
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ base_image_input_size = image_processor.base_image_input_size
|
|
|
+ image_patch_size = image_processor.image_patch_size
|
|
|
+ image_num_patch = (
|
|
|
+ base_image_input_size[0] // image_patch_size,
|
|
|
+ base_image_input_size[1] // image_patch_size,
|
|
|
+ )
|
|
|
+ n_pixels = image_patch_size * image_patch_size * 3
|
|
|
+ n_patches = image_num_patch[0] * image_num_patch[1]
|
|
|
+
|
|
|
+ image_length_w = image_processor.image_token_length_w
|
|
|
+ image_length_h = image_processor.image_token_length_h
|
|
|
+ tokens_per_image = image_length_w * image_length_h
|
|
|
+ images = torch.full(
|
|
|
+ (max_total_crops, n_patches, n_pixels),
|
|
|
+ -1,
|
|
|
+ dtype=torch.float32,
|
|
|
+ )
|
|
|
+ image_input_idx = torch.full(
|
|
|
+ (max_total_crops, tokens_per_image),
|
|
|
+ -1,
|
|
|
+ dtype=torch.int32,
|
|
|
+ )
|
|
|
+ if image_processor.image_padding_mask:
|
|
|
+ image_masks = torch.full(
|
|
|
+ (max_total_crops, n_patches),
|
|
|
+ -1,
|
|
|
+ dtype=torch.float32,
|
|
|
+ )
|
|
|
+
|
|
|
+ image_data = dict(
|
|
|
+ images=images,
|
|
|
+ image_input_idx=image_input_idx,
|
|
|
+ )
|
|
|
+ if image_masks is not None:
|
|
|
+ image_data["image_masks"] = image_masks
|
|
|
+
|
|
|
+ image_data["seq_len"] = torch.tensor(len(out["input_ids"]),
|
|
|
+ dtype=torch.long)
|
|
|
+
|
|
|
+ multi_modal_data = dict(image=image_data)
|
|
|
+
|
|
|
+ return LLMInputs(
|
|
|
+ prompt_token_ids=out["input_ids"],
|
|
|
+ prompt=llm_inputs["prompt"],
|
|
|
+ multi_modal_data=multi_modal_data,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo)
|
|
|
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
|
|
|
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
|
|
|
+@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
|
|
|
+class MolmoForCausalLM(nn.Module, SupportsMultiModal):
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ config: PretrainedConfig,
|
|
|
+ multimodal_config: Optional[MultiModalConfig] = None,
|
|
|
+ cache_config: Optional[CacheConfig] = None,
|
|
|
+ quant_config: Optional[Mapping[str, Any]] = None,
|
|
|
+ ) -> None:
|
|
|
+ super().__init__()
|
|
|
+
|
|
|
+ self.config = config
|
|
|
+ self.multimodal_config = multimodal_config
|
|
|
+
|
|
|
+ vision_config = VisionBackboneConfig()
|
|
|
+ self.vision_backbone = MolmoVisionBackbone(config, vision_config,
|
|
|
+ quant_config)
|
|
|
+ self.model = MolmoModel(config, cache_config, quant_config)
|
|
|
+
|
|
|
+ if self.config.weight_tying:
|
|
|
+ self.lm_head = self.model.transformer.wte
|
|
|
+ else:
|
|
|
+ self.lm_head = ParallelLMHead(
|
|
|
+ config.embedding_size or config.vocab_size,
|
|
|
+ config.hidden_size,
|
|
|
+ quant_config=quant_config,
|
|
|
+ )
|
|
|
+
|
|
|
+ self.logits_processor = LogitsProcessor(config.embedding_size
|
|
|
+ or config.vocab_size)
|
|
|
+ self.sampler = Sampler()
|
|
|
+
|
|
|
+ def _parse_and_validate_image_input(
|
|
|
+ self,
|
|
|
+ **kwargs: object,
|
|
|
+ ) -> Optional[MolmoImageInputs]:
|
|
|
+ images = kwargs.pop("images", None)
|
|
|
+ image_masks = kwargs.pop("image_masks", None)
|
|
|
+ if images is None:
|
|
|
+ return None
|
|
|
+
|
|
|
+ image_input_idx = kwargs.pop("image_input_idx", None)
|
|
|
+ seq_len = kwargs.pop("seq_len", None)
|
|
|
+ if image_input_idx is None:
|
|
|
+ raise ValueError("image_input_idx is required for Molmo model.")
|
|
|
+ if seq_len is None:
|
|
|
+ raise ValueError("seq_len is required for Molmo model.")
|
|
|
+ if not isinstance(seq_len, torch.Tensor):
|
|
|
+ seq_len = torch.tensor(seq_len)
|
|
|
+
|
|
|
+ return MolmoImageInputs(
|
|
|
+ images=images,
|
|
|
+ image_input_idx=image_input_idx,
|
|
|
+ seq_len=seq_len,
|
|
|
+ image_masks=image_masks,
|
|
|
+ )
|
|
|
+
|
|
|
+ def _process_image_input(
|
|
|
+ self,
|
|
|
+ image_input: MolmoImageInputs,
|
|
|
+ ) -> torch.Tensor:
|
|
|
+
|
|
|
+ image_features = self.vision_backbone(
|
|
|
+ images=image_input["images"],
|
|
|
+ image_masks=image_input["image_masks"],
|
|
|
+ )
|
|
|
+
|
|
|
+ return image_features
|
|
|
+
|
|
|
+ def _merge_multimodal_embeddings(
|
|
|
+ self,
|
|
|
+ inputs_embeds: torch.Tensor,
|
|
|
+ image_features: torch.Tensor,
|
|
|
+ image_input_idx: torch.Tensor,
|
|
|
+ seq_len: Union[torch.Tensor, List[torch.Tensor]],
|
|
|
+ ) -> torch.Tensor:
|
|
|
+ batch_size, num_image, num_patch = image_features.shape[:3]
|
|
|
+ assert image_input_idx.shape == (batch_size, num_image, num_patch)
|
|
|
+
|
|
|
+ image_features = image_features.to(inputs_embeds.device)
|
|
|
+ seq_len = seq_len.to(inputs_embeds.device)
|
|
|
+
|
|
|
+
|
|
|
+ image_features = image_features.view(batch_size, num_image * num_patch,
|
|
|
+ -1)
|
|
|
+ image_input_idx = image_input_idx.view(batch_size,
|
|
|
+ num_image * num_patch)
|
|
|
+
|
|
|
+ valid = image_input_idx >= 0
|
|
|
+ image_features = image_features * valid[:, :, None].to(
|
|
|
+ image_features.dtype)
|
|
|
+ image_features = image_features.view(
|
|
|
+ batch_size * num_image * num_patch, -1).contiguous()
|
|
|
+
|
|
|
+ image_input_idx = image_input_idx * valid.to(image_input_idx.dtype)
|
|
|
+ offset = torch.cat(
|
|
|
+ [seq_len.new_zeros(
|
|
|
+ (1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None]
|
|
|
+ image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
|
|
|
+ image_input_idx = image_input_idx.flatten()[:, None]
|
|
|
+ mat = image_input_idx == torch.arange(
|
|
|
+ seq_len.sum().item(), device=inputs_embeds.device)[None, :]
|
|
|
+ mat = mat.to(image_features.dtype)
|
|
|
+
|
|
|
+ inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md',
|
|
|
+ image_features, mat)
|
|
|
+
|
|
|
+ return inputs_embeds
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ input_ids: torch.LongTensor,
|
|
|
+ positions: torch.LongTensor,
|
|
|
+ kv_caches: List[torch.Tensor],
|
|
|
+ attn_metadata: AttentionMetadata,
|
|
|
+ **kwargs: object,
|
|
|
+ ) -> SamplerOutput:
|
|
|
+
|
|
|
+ image_input = self._parse_and_validate_image_input(**kwargs)
|
|
|
+
|
|
|
+ if image_input is not None:
|
|
|
+ inputs_embeds = self.model.embed_tokens(input_ids)
|
|
|
+ image_features = self._process_image_input(image_input)
|
|
|
+
|
|
|
+ inputs_embeds = self._merge_multimodal_embeddings(
|
|
|
+ inputs_embeds,
|
|
|
+ image_features,
|
|
|
+ image_input["image_input_idx"],
|
|
|
+ image_input["seq_len"],
|
|
|
+ )
|
|
|
+
|
|
|
+ input_ids = None
|
|
|
+ else:
|
|
|
+ inputs_embeds = None
|
|
|
+
|
|
|
+ hidden_states = self.model(
|
|
|
+ input_ids=input_ids,
|
|
|
+ positions=positions,
|
|
|
+ kv_caches=kv_caches,
|
|
|
+ attn_metadata=attn_metadata,
|
|
|
+ inputs_embeds=inputs_embeds,
|
|
|
+ )
|
|
|
+
|
|
|
+ return hidden_states
|
|
|
+
|
|
|
+ def compute_logits(self, hidden_states: torch.Tensor,
|
|
|
+ sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
|
|
+ logits = self.logits_processor(self.lm_head, hidden_states,
|
|
|
+ sampling_metadata)
|
|
|
+ return logits
|
|
|
+
|
|
|
+ def sample(
|
|
|
+ self,
|
|
|
+ logits: torch.Tensor,
|
|
|
+ sampling_metadata: SamplingMetadata,
|
|
|
+ ) -> Optional[SamplerOutput]:
|
|
|
+ next_tokens = self.sampler(logits, sampling_metadata)
|
|
|
+ return next_tokens
|
|
|
+
|
|
|
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
|
|
+
|
|
|
+ params_mapping = [
|
|
|
+ ("model.transformer.ln_f.weight", "model.norm.weight"),
|
|
|
+ ("attn_out", "self_attn.o_proj"),
|
|
|
+ ("att_proj", "self_attn.qkv_proj"),
|
|
|
+ ("q_norm", "self_attn.q_norm"),
|
|
|
+ ("k_norm", "self_attn.k_norm"),
|
|
|
+ ("attn_norm", "input_layernorm"),
|
|
|
+ ("ff_norm", "post_attention_layernorm"),
|
|
|
+ ]
|
|
|
+
|
|
|
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
|
|
|
+
|
|
|
+ embedding_weight = dict()
|
|
|
+ projector_weight = dict()
|
|
|
+ for name, loaded_weight in weights:
|
|
|
+ if "rotary_emb.inv_freq" in name:
|
|
|
+ continue
|
|
|
+ if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
|
|
+ continue
|
|
|
+
|
|
|
+ if "wte.embedding" in name:
|
|
|
+ embedding_weight["embedding"] = loaded_weight
|
|
|
+ continue
|
|
|
+
|
|
|
+ if "wte.new_embedding" in name:
|
|
|
+ embedding_weight["new_embedding"] = loaded_weight
|
|
|
+ continue
|
|
|
+
|
|
|
+ if "vision_backbone" in name:
|
|
|
+ if name.startswith("model"):
|
|
|
+ name = name[len("model."):]
|
|
|
+ if 'image_projector' in name:
|
|
|
+ if 'w1' in name:
|
|
|
+ projector_weight['gate_proj'] = loaded_weight
|
|
|
+ elif 'w3' in name:
|
|
|
+ projector_weight['up_proj'] = loaded_weight
|
|
|
+ elif 'w2' in name:
|
|
|
+ projector_weight['down_proj'] = loaded_weight
|
|
|
+ else:
|
|
|
+ raise ValueError(
|
|
|
+ f"Unexpected projector weight: {name}")
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ if "transformer.blocks" in name:
|
|
|
+ name = name.replace("transformer.blocks", "layers")
|
|
|
+
|
|
|
+ if "ff_proj" in name:
|
|
|
+ name = name.replace("ff_proj", "mlp.gate_up_proj")
|
|
|
+ assert 'weight' in name
|
|
|
+ up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
|
|
|
+ loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
|
|
|
+
|
|
|
+ elif "ff_out" in name:
|
|
|
+ if "layers" in name:
|
|
|
+ name = name.replace("ff_out", "mlp.down_proj")
|
|
|
+ else:
|
|
|
+
|
|
|
+ name = name.replace("model.transformer.ff_out",
|
|
|
+ "lm_head")
|
|
|
+
|
|
|
+ else:
|
|
|
+ for (param_name, weight_name) in params_mapping:
|
|
|
+ if param_name in name:
|
|
|
+ name = name.replace(param_name, weight_name)
|
|
|
+ break
|
|
|
+
|
|
|
+ try:
|
|
|
+
|
|
|
+ if name.endswith(".bias") and name not in params_dict:
|
|
|
+ continue
|
|
|
+ param = params_dict[name]
|
|
|
+ except KeyError:
|
|
|
+ raise ValueError(f"Unexpected weight: {name}") from None
|
|
|
+
|
|
|
+ weight_loader = getattr(param, "weight_loader",
|
|
|
+ default_weight_loader)
|
|
|
+ weight_loader(param, loaded_weight)
|
|
|
+
|
|
|
+ gate_up_proj_weight = torch.cat(
|
|
|
+ [projector_weight["gate_proj"], projector_weight["up_proj"]],
|
|
|
+ dim=0)
|
|
|
+ name = "vision_backbone.image_projector.gate_up_proj.weight"
|
|
|
+ param = params_dict[name]
|
|
|
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
|
+ weight_loader(param, gate_up_proj_weight)
|
|
|
+
|
|
|
+ down_proj_weight = projector_weight["down_proj"]
|
|
|
+ name = "vision_backbone.image_projector.down_proj.weight"
|
|
|
+ param = params_dict[name]
|
|
|
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
|
+ weight_loader(param, down_proj_weight)
|
|
|
+
|
|
|
+ embedding_weight = torch.cat(
|
|
|
+ [embedding_weight["embedding"], embedding_weight["new_embedding"]],
|
|
|
+ dim=0)
|
|
|
+ name = "model.embed_tokens.weight"
|
|
|
+ param = params_dict[name]
|
|
|
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
|
+ weight_loader(param, embedding_weight)
|