# molmo.py

import logging
import math
import re
from array import array
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict,
                    Union)

import torch
from einops import rearrange
from PIL import Image
from torch import nn
from torch.nn import functional as F
from transformers import PretrainedConfig

import aphrodite.common.envs as envs
from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.attention.selector import (_Backend, backend_name_to_enum,
                                          get_global_forced_attn_backend)
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.logger import log_once
from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       IntermediateTensors, SequenceData)
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import QuickGELU, SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsMultiModal
from aphrodite.modeling.models.utils import make_layers
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from aphrodite.platforms import current_platform
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.processor import get_processor

log = logging.getLogger(__name__)

# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
NUM_PREFIX_TOKENS = 1
ADDITIONAL_VOCAB_SIZE = 128
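# NOTE: VIT_LAYERS selects which ViT hidden states are concatenated as image
# features (see MolmoVisionBackbone.encode_image), NUM_PREFIX_TOKENS counts the
# ViT class token that is stripped before pooling, and ADDITIONAL_VOCAB_SIZE
# reserves extra embedding rows for Molmo's added special tokens (the
# checkpoint's `wte.new_embedding` rows are appended to the base embedding in
# `load_weights`).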


class MolmoImageInputs(TypedDict):
    images: torch.Tensor
    """Shape:
    `(batch_size, num_crops, num_patch, patch_dim)`
    """

    image_input_idx: torch.Tensor
    """Shape:
    `(batch_size, num_crops, num_patch)`
    """

    seq_len: torch.Tensor
    """Shape:
    `(batch_size, )`
    """

    image_masks: Optional[torch.Tensor]
    """Shape:
    `(batch_size, num_crops, num_patch)`
    """


@dataclass
class VisionBackboneConfig:
    image_default_input_size: Tuple[int, int] = (336, 336)
    image_patch_size: int = 14
    image_pos_patch_size: int = 14
    image_emb_dim: int = 1024
    image_num_heads: int = 16
    image_num_key_value_heads: int = 16
    image_num_layers: int = 23
    image_mlp_dim: int = 4096
    image_mlp_activations: str = "quick_gelu"
    image_num_pos: int = 577
    image_norm_eps: float = 1e-5

    def __post_init__(self):
        self.image_default_input_size = tuple(
            self.image_default_input_size)  # type: ignore[assignment]

    @property
    def image_num_patch(self):
        h, w = self.image_default_input_size
        return h // self.image_patch_size, w // self.image_patch_size


class ViTMLP(nn.Module):
    """MLP used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.w1 = ColumnParallelLinear(
            config.image_emb_dim,
            config.image_mlp_dim,
            bias=True,
            quant_config=quant_config,
        )
        # Activation function.
        assert config.image_mlp_activations == "quick_gelu"
        self.act = QuickGELU()
        self.w2 = RowParallelLinear(
            config.image_mlp_dim,
            config.image_emb_dim,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.w1(x)
        x = self.act(x)
        x, _ = self.w2(x)
        return x


class MultiHeadDotProductAttention(nn.Module):
    """Multi-head attention used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        use_bias: bool = True,
        nlayers: int = 1,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.image_emb_dim
        self.total_num_heads = config.image_num_heads
        tp_size = get_tensor_model_parallel_world_size()

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = config.image_num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.wq = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wk = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wv = ColumnParallelLinear(
            nlayers * self.hidden_size,
            self.total_num_kv_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
        )

        # Detect attention implementation.
        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
        if selected_backend is None:
            backend_by_env_var: Optional[str] = envs.APHRODITE_ATTENTION_BACKEND
            if backend_by_env_var is not None:
                selected_backend = backend_name_to_enum(backend_by_env_var)
        if selected_backend is None:
            # Flash attention needs compute capability >= 8.0 (Ampere or
            # newer); fall back to xformers on Volta and Turing GPUs.
            device_available = current_platform.get_device_capability()[0] >= 8
            if device_available:
                from transformers.utils import is_flash_attn_2_available

                if is_flash_attn_2_available():
                    self._use_flash_attn = True
                else:
                    log_once(
                        level="WARNING",
                        message=
                        "The current Molmo implementation has a bug with "
                        "`aphrodite-flash-attn` inside the vision module, so "
                        "the xformers backend is used instead. You can run "
                        "`pip install flash-attn` to use the flash-attention "
                        "backend.")
                    self._use_flash_attn = False
            else:
                self._use_flash_attn = False
        else:
            if selected_backend == _Backend.FLASH_ATTN:
                self._use_flash_attn = True
            elif selected_backend == _Backend.XFORMERS:
                self._use_flash_attn = False
            else:
                raise RuntimeError(
                    f"Molmo does not support {selected_backend} backend now.")

    def forward(self,
                inputs_q: torch.Tensor,
                inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
        if inputs_kv is not None:
            inputs_k = inputs_kv
            inputs_v = inputs_kv
        else:
            inputs_k = inputs_q
            inputs_v = inputs_q

        xq, _ = self.wq(inputs_q)
        xk, _ = self.wk(inputs_k)
        xv, _ = self.wv(inputs_v)

        q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim)
        kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim)
        xq = xq.view(*q_shape)
        xk = xk.view(*kv_shape)
        xv = xv.view(*kv_shape)

        if self._use_flash_attn:
            from flash_attn import flash_attn_func
            output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
        else:
            from xformers import ops as xops
            output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)

        output = rearrange(output, "b s h d -> b s (h d)").contiguous()
        output, _ = self.wo(output)
        return output
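
# NOTE: the vision attention backend chosen above can also be forced
# explicitly, e.g. via the APHRODITE_ATTENTION_BACKEND environment variable
# (presumably set to the backend enum name, such as FLASH_ATTN or XFORMERS);
# any other backend is rejected for Molmo's vision module.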


class ResidualAttentionBlock(nn.Module):
    """Residual attention block used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.attention = MultiHeadDotProductAttention(
            config, quant_config=quant_config)
        self.feed_forward = ViTMLP(config, quant_config)
        self.attention_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )
        self.ffn_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x))
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


class BlockCollection(nn.Module):
    """Collection of residual attention blocks used in Vision Transformer."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(config, quant_config)
            for _ in range(config.image_num_layers)
        ])

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        hidden_states = []
        for r in self.resblocks:
            x = r(x)
            hidden_states.append(x)
        return hidden_states


def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor:
    return token.view(1, 1, -1).expand(batch_size, -1, -1)


class VisionTransformer(nn.Module):
    """Vision Transformer used in Vision Backbone."""

    def __init__(
        self,
        config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        scale = config.image_emb_dim**-0.5
        self.patch_num = config.image_num_patch
        self.class_embedding = nn.Parameter(
            torch.randn(config.image_emb_dim) * scale)
        self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
        self.positional_embedding = nn.Parameter(
            torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
        image_patch_size = config.image_patch_size
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.image_emb_dim,
            bias=False,
        )
        self.pre_ln = nn.LayerNorm(config.image_emb_dim,
                                   eps=config.image_norm_eps)
        self.transformer = BlockCollection(config, quant_config)

    def add_pos_emb(self, x: torch.Tensor,
                    patch_num: Tuple[int, int]) -> torch.Tensor:
        cls_emb = self.positional_embedding[0:1]
        pos_emb = self.positional_embedding[1:]

        pos_emb = pos_emb.reshape(
            (int(math.sqrt(pos_emb.shape[0])),
             int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))

        (patch_num_0, patch_num_1) = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # Interpolate the positional embeddings to the actual patch grid,
            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
                          dim=1).to(x.dtype)
        return x

    def forward(self,
                x: torch.Tensor,
                patch_num: Optional[Tuple[int, int]] = None
                ) -> List[torch.Tensor]:
        """
        :param x: (batch_size, num_patch, n_pixels)
        """
        if patch_num is None:
            patch_num = self.patch_num
        B, N, D = x.shape

        x = self.patch_embedding(x)

        # Prepend the class embedding and add positional embeddings.
        x = torch.cat(
            [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
            dim=1)
        x = self.add_pos_emb(x, patch_num)

        x = self.pre_ln(x)

        hidden_states = self.transformer(x)
        return hidden_states


class MolmoAttention(nn.Module):
    """Molmo's LLM attention."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = config.num_key_value_heads \
            or self.total_num_heads
        if self.total_num_kv_heads >= self.tp_size:
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            assert self.tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )
        self.tp_rank: Optional[int] = None
        self.k_norm: Optional[nn.Module] = None
        self.q_norm: Optional[nn.Module] = None
        if config.attention_layer_norm:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
                                  eps=config.layer_norm_eps)
            self.q_norm = RMSNorm(config.hidden_size,
                                  eps=config.layer_norm_eps)

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm.forward_native(q)
        k = self.k_norm.forward_native(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.q_norm is not None and self.k_norm is not None:
            q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MolmoMLP(nn.Module):
    """Molmo's LLM mlp."""

    def __init__(
        self,
        config: PretrainedConfig,
        input_dim: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # The config's intermediate_size covers the fused gate+up projection,
        # so each individual projection is half of it.
        self.intermediate_size = config.intermediate_size // 2

        # Feed-forward input projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_dim or self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )

        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MolmoDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Attention block.
        self.self_attn = MolmoAttention(config, cache_config, quant_config)

        # MLP block.
        self.mlp = MolmoMLP(config, quant_config=quant_config)

        # LayerNorm
        assert config.layer_norm_type == "rms"
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.layer_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
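

# The "norm after" variant below applies its layer norms to the sublayer
# outputs (post-norm) rather than to their inputs, matching checkpoints with
# `config.norm_after` set. It handles the residual stream inside the layer and
# therefore always returns residual=None.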
class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = hidden_states + residual
        residual = None
        return hidden_states, residual


class MolmoVisionBackbone(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        vision_config: VisionBackboneConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.vit_layers = VIT_LAYERS
        self.image_num_patch = vision_config.image_num_patch
        self.llm_patches_per_crop = (
            (self.image_num_patch[0] + 1) // 2,
            (self.image_num_patch[1] + 1) // 2,
        )
        self.image_vit = VisionTransformer(vision_config,
                                           quant_config=quant_config)
        self.num_prefix_tokens = self.image_vit.num_prefix_tokens
        assert self.num_prefix_tokens in {
            0, 1
        }, "Only 0 or 1 prefix tokens are supported"
        self.image_pooling_2d = MultiHeadDotProductAttention(
            vision_config,
            nlayers=len(self.vit_layers),
            quant_config=quant_config)
        self.image_projector = MolmoMLP(
            config,
            input_dim=vision_config.image_emb_dim,
            quant_config=quant_config,
        )

        image_dim = vision_config.image_emb_dim * len(self.vit_layers)
        self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))

    @property
    def dtype(self) -> torch.dtype:
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.image_vit.patch_embedding.weight.device

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        :param images: (batch_size, num_crops, num_patch, n_pixels)
        """
        B, T, N, D = images.shape

        mask = ~torch.all(
            images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)

        images = images.view(B * T, N, D)
        image_features = self.image_vit(images)

        if self.vit_layers is not None:
            features = []
            for layer in self.vit_layers:
                features.append(image_features[layer])
            image_features = torch.cat(features, dim=-1)
        else:
            image_features = image_features[-1]

        if self.num_prefix_tokens > 0:
            image_features = image_features[:, 1:]

        image_features = image_features * mask
        image_features = image_features.view(B, T, N, -1)
        return image_features
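
    # In forward() below, each crop's patch grid (24x24 with the default
    # 336/14 vision config, padded when odd) is pooled 2x2 by
    # `image_pooling_2d`, using the mean of each 2x2 window as the attention
    # query, and the pooled features are then projected to the LLM hidden size
    # by `image_projector`.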
    def forward(self, images: torch.Tensor,
                image_masks: torch.Tensor) -> torch.Tensor:
        # image_features:
        #   (batch_size, num_crops(=num_image), num_patch, n_layers * image_emb_dim)
        batch_size, num_image = images.shape[:2]
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.encode_image(images)

        og_dtype = image_features.dtype
        assert image_masks is not None
        pad_embed = self.pad_embed[:, None, None, None, :]
        all_pad = image_masks == 0
        partial_pad = torch.logical_and(
            image_masks < 1,
            torch.logical_not(all_pad)).to(dtype=torch.float32)
        all_pad = all_pad.to(dtype=torch.float32)
        image_features = image_features + pad_embed[0] * torch.unsqueeze(
            all_pad, -1)
        image_features = image_features + pad_embed[1] * torch.unsqueeze(
            partial_pad, -1)
        image_features = image_features.to(og_dtype)

        image_features = image_features.reshape(
            (batch_size, num_image) + self.image_num_patch + (-1, ))

        if self.image_num_patch[0] % 2 == 1:
            # Pad so we can still pool 2x2 patches
            image_features = F.pad(
                image_features,
                (0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
            )

        # image pooling
        image_features = rearrange(
            image_features,
            'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
            dh=2,
            dw=2,
        )
        query = image_features.mean(-2, keepdim=True)
        image_features = self.image_pooling_2d(query, image_features)

        h, w = self.llm_patches_per_crop
        image_features = image_features.view(batch_size, num_image, h * w, -1)
        image_features = self.image_projector(image_features)

        # image_features: (batch_size, num_image, num_patch, d_model)
        return image_features


class MolmoModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        self.embedding_size = config.embedding_size or config.vocab_size
        self.embedding_size += ADDITIONAL_VOCAB_SIZE
        self.embed_tokens = VocabParallelEmbedding(
            self.embedding_size,
            config.hidden_size,
            quant_config=quant_config,
        )

        decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \
            else MolmoDecoderLayer
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(config, cache_config, quant_config),
            prefix=f"{prefix}.layers",
        )

        assert config.layer_norm_type == "rms"
        self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Apply blocks one-by-one.
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states


cached_get_processor = lru_cache(get_processor)


def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int,
                    right_margin: int, pooling_size: int) -> int:
    crop_window_patches = crop_patches - (left_margin + right_margin)
    if num_tiles > 1:
        left_crop_window_patches = (crop_window_patches + left_margin +
                                    pooling_size -
                                    1) // pooling_size * pooling_size
        middle_crop_window_patches = (crop_window_patches + pooling_size -
                                      1) // pooling_size * pooling_size
        right_crop_window_patches = (crop_window_patches + right_margin +
                                     pooling_size -
                                     1) // pooling_size * pooling_size
        return left_crop_window_patches + (
            num_tiles -
            2) * middle_crop_window_patches + right_crop_window_patches
    else:
        single_crop_window_patches = (crop_patches + pooling_size -
                                      1) // pooling_size * pooling_size
        return single_crop_window_patches


def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int,
               left_margin: int, right_margin: int, pooling_size: int) -> int:
    h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin,
                        pooling_size)
    w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin,
                        pooling_size)
    per_row = w // pooling_size + 1
    joint = per_row * (h // pooling_size) + 2
    image_token_length = (crop_patches + pooling_size - 1) // pooling_size
    resize = (image_token_length + 1) * image_token_length + 2
    return resize + joint
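
# Illustrative example (values assumed for the sketch, not read from any
# checkpoint): with a 336x336 base crop and 14-pixel patches,
# crop_patches = 336 // 14 = 24; with overlap margins (4, 4) and
# pooling_size = 2, a (2, 1) tiling gives h = get_num_patches(2, 24, 4, 4, 2)
# = 40 and w = get_num_patches(1, 24, 4, 4, 2) = 24, so
# joint = (24 // 2 + 1) * (40 // 2) + 2 = 262, resize = 13 * 12 + 2 = 158,
# and get_tokens(2, 1, 24, 4, 4, 2) = 158 + 262 = 420 image tokens.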


def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int,
                   right_margin: int, pooling_size: int) -> int:
    tilings = []
    for i in range(1, max_crops + 1):
        for j in range(1, max_crops + 1):
            if i * j <= max_crops:
                tilings.append((i, j))
    tokens = [
        get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin,
                   right_margin, pooling_size) for i in range(len(tilings))
    ]
    return max(tokens)


def get_max_molmo_image_tokens(ctx: InputContext) -> int:
    processor = cached_get_processor(ctx.model_config.model,
                                     trust_remote_code=True,
                                     revision=ctx.model_config.code_revision)
    image_processor = processor.image_processor
    max_llm_image_tokens = get_max_tokens(
        image_processor.max_crops,
        image_processor.base_image_input_size[0] //
        image_processor.image_patch_size,
        image_processor.overlap_margins[0],
        image_processor.overlap_margins[1],
        2,
    )
    return max_llm_image_tokens


# NOTE: preprocessing for the image data has been included in the
# 'input_processor_for_molmo' function
def image_input_mapper_for_molmo(
    ctx: InputContext,
    data: object,
):
    return MultiModalInputs(data)


def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    processor = cached_get_processor(ctx.model_config.model,
                                     trust_remote_code=True,
                                     revision=ctx.model_config.code_revision)
    image_processor = processor.image_processor

    base_image_input_d = image_processor.image_patch_size
    left_margin, right_margin = image_processor.overlap_margins
    max_crops = image_processor.max_crops

    # Assume prompt_token_ids always starts with bos_token_id,
    # followed by the image tokens.
    max_llm_image_tokens = get_max_molmo_image_tokens(ctx)
    if seq_len - max_llm_image_tokens - 1 < 0:
        raise RuntimeError(
            f"Molmo cannot process {max_crops} crops in a prompt; "
            "please increase max_model_len or reduce the number of crops.")

    # The vertical image has the maximum number of image tokens due to
    # column tokens.
    tiling = (max_crops, 1)
    total_margin_pixels = base_image_input_d * (right_margin + left_margin)
    crop_patches = (image_processor.base_image_input_size[0] //
                    base_image_input_d)
    crop_window_patches = crop_patches - (right_margin + left_margin)
    crop_window_size = crop_window_patches * base_image_input_d

    h = crop_window_size * tiling[0] + total_margin_pixels
    w = crop_window_size * tiling[1] + total_margin_pixels

    dummy_image = Image.new("RGB", (w, h), color="red")

    out = processor.process("dummy prompt", dummy_image)

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      out["input_ids"][:1 + max_llm_image_tokens])
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - max_llm_image_tokens - 1)
    dummy_seqdata = SequenceData(token_ids)
    dummy_imgdata = {
        "images": out["images"],
        "image_input_idx": out["image_input_idx"],
    }
    if "image_masks" in out:
        dummy_imgdata["image_masks"] = out["image_masks"]
    dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
    return dummy_seqdata, {"image": dummy_imgdata}


def pad_images(
    max_total_crops: int,
    images: torch.Tensor,
    image_input_idx: torch.Tensor,
    image_masks: Optional[torch.Tensor] = None,
):
    n = max_total_crops - images.shape[0]
    images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1)
    image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1)
    if image_masks is not None:
        image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1)
    return images, image_input_idx, image_masks


def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
    prompt = llm_inputs["prompt"]
    multi_modal_data = llm_inputs.get("multi_modal_data")
    image = multi_modal_data.get("image")
    processor = cached_get_processor(ctx.model_config.model,
                                     trust_remote_code=True,
                                     revision=ctx.model_config.code_revision)

    # NOTE: message formatting for a raw text prompt is only applied for
    # offline inference; for online inference, the prompt is always in
    # instruction format and tokenized.
    if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$",
                                       prompt):
        out = processor.process(prompt, image, message_format="none")
    elif prompt is not None:
        out = processor.process(prompt, image)
    else:
        out = processor.process(None,
                                image,
                                tokens=llm_inputs["prompt_token_ids"])

    image_processor = processor.image_processor
    max_total_crops = 1 + image_processor.max_crops
    if image is not None:
        images, image_input_idx, image_masks = pad_images(
            max_total_crops,
            out["images"],
            out["image_input_idx"],
            out.get("image_masks"),
        )
    else:
        base_image_input_size = image_processor.base_image_input_size
        image_patch_size = image_processor.image_patch_size
        image_num_patch = (
            base_image_input_size[0] // image_patch_size,
            base_image_input_size[1] // image_patch_size,
        )
        n_pixels = image_patch_size * image_patch_size * 3
        n_patches = image_num_patch[0] * image_num_patch[1]

        image_length_w = image_processor.image_token_length_w
        image_length_h = image_processor.image_token_length_h
        tokens_per_image = image_length_w * image_length_h
        images = torch.full(
            (max_total_crops, n_patches, n_pixels),
            -1,
            dtype=torch.float32,
        )
        image_input_idx = torch.full(
            (max_total_crops, tokens_per_image),
            -1,
            dtype=torch.int32,
        )
        # The padding mask is only produced when the image processor uses one,
        # so default to None to avoid an unbound reference below.
        image_masks = None
        if image_processor.image_padding_mask:
            image_masks = torch.full(
                (max_total_crops, n_patches),
                -1,
                dtype=torch.float32,
            )

    image_data = dict(
        images=images,
        image_input_idx=image_input_idx,
    )
    if image_masks is not None:
        image_data["image_masks"] = image_masks
    image_data["seq_len"] = torch.tensor(len(out["input_ids"]),
                                         dtype=torch.long)

    multi_modal_data = dict(image=image_data)

    return LLMInputs(
        prompt_token_ids=out["input_ids"],
        prompt=llm_inputs["prompt"],
        multi_modal_data=multi_modal_data,
    )
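

# The decorators below wire Molmo into the multimodal machinery: the image
# input mapper passes the already pre-processed tensors through unchanged, the
# max-image-token hook bounds memory profiling, and the dummy-data / input
# processor hooks build profiling inputs and expand real prompts with image
# tokens.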
@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: Optional[MultiModalConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[Mapping[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        vision_config = VisionBackboneConfig()
        self.vision_backbone = MolmoVisionBackbone(config, vision_config,
                                                   quant_config)
        self.model = MolmoModel(config, cache_config, quant_config)

        if self.config.weight_tying:
            # Tie the output head to the token embedding.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.embedding_size or config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )

        self.logits_processor = LogitsProcessor(config.embedding_size
                                                or config.vocab_size)
        self.sampler = Sampler()

    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> Optional[MolmoImageInputs]:
        images = kwargs.pop("images", None)
        image_masks = kwargs.pop("image_masks", None)
        if images is None:
            return None

        image_input_idx = kwargs.pop("image_input_idx", None)
        seq_len = kwargs.pop("seq_len", None)
        if image_input_idx is None:
            raise ValueError("image_input_idx is required for Molmo model.")
        if seq_len is None:
            raise ValueError("seq_len is required for Molmo model.")
        if not isinstance(seq_len, torch.Tensor):
            seq_len = torch.tensor(seq_len)

        return MolmoImageInputs(
            images=images,
            image_input_idx=image_input_idx,
            seq_len=seq_len,
            image_masks=image_masks,
        )

    def _process_image_input(
        self,
        image_input: MolmoImageInputs,
    ) -> torch.Tensor:
        image_features = self.vision_backbone(
            images=image_input["images"],
            image_masks=image_input["image_masks"],
        )
        return image_features
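
    # _merge_multimodal_embeddings scatters the image features into the text
    # embeddings: image_input_idx holds, for every image patch token, its
    # position in the flattened, batch-concatenated prompt, and the scatter is
    # expressed as a one-hot matrix multiply so that invalid (-1) positions
    # are simply zeroed out.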
    def _merge_multimodal_embeddings(
        self,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
        image_input_idx: torch.Tensor,
        seq_len: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        batch_size, num_image, num_patch = image_features.shape[:3]
        assert image_input_idx.shape == (batch_size, num_image, num_patch)

        image_features = image_features.to(inputs_embeds.device)
        seq_len = seq_len.to(inputs_embeds.device)

        # Insert the image features into the embedding.
        image_features = image_features.view(batch_size,
                                             num_image * num_patch, -1)
        image_input_idx = image_input_idx.view(batch_size,
                                               num_image * num_patch)

        valid = image_input_idx >= 0
        image_features = image_features * valid[:, :, None].to(
            image_features.dtype)
        image_features = image_features.view(
            batch_size * num_image * num_patch, -1).contiguous()

        image_input_idx = image_input_idx * valid.to(image_input_idx.dtype)
        offset = torch.cat(
            [seq_len.new_zeros(
                (1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None]
        image_input_idx = image_input_idx + offset.to(image_input_idx.dtype)
        image_input_idx = image_input_idx.flatten()[:, None]
        mat = image_input_idx == torch.arange(
            seq_len.sum().item(), device=inputs_embeds.device)[None, :]
        mat = mat.to(image_features.dtype)

        inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md',
                                                     image_features, mat)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> SamplerOutput:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            image_features = self._process_image_input(image_input)

            inputs_embeds = self._merge_multimodal_embeddings(
                inputs_embeds,
                image_features,
                image_input["image_input_idx"],
                image_input["seq_len"],
            )
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
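
    # load_weights remaps checkpoint parameter names onto this module
    # hierarchy (see params_mapping below), splits the fused `ff_proj` weight
    # back into gate/up halves, and stitches the vision projector's w1/w3
    # weights and the `wte.embedding` + `wte.new_embedding` rows into single
    # tensors before loading them.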
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_mapping = [
            ("model.transformer.ln_f.weight", "model.norm.weight"),
            ("attn_out", "self_attn.o_proj"),
            ("att_proj", "self_attn.qkv_proj"),
            ("q_norm", "self_attn.q_norm"),
            ("k_norm", "self_attn.k_norm"),
            ("attn_norm", "input_layernorm"),
            ("ff_norm", "post_attention_layernorm"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))

        embedding_weight = dict()
        projector_weight = dict()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            if "wte.embedding" in name:
                embedding_weight["embedding"] = loaded_weight
                continue

            if "wte.new_embedding" in name:
                embedding_weight["new_embedding"] = loaded_weight
                continue

            if "vision_backbone" in name:
                if name.startswith("model"):
                    name = name[len("model."):]
                if 'image_projector' in name:
                    if 'w1' in name:
                        projector_weight['gate_proj'] = loaded_weight
                    elif 'w3' in name:
                        projector_weight['up_proj'] = loaded_weight
                    elif 'w2' in name:
                        projector_weight['down_proj'] = loaded_weight
                    else:
                        raise ValueError(
                            f"Unexpected projector weight: {name}")
                    continue
            else:
                if "transformer.blocks" in name:
                    name = name.replace("transformer.blocks", "layers")

                if "ff_proj" in name:
                    name = name.replace("ff_proj", "mlp.gate_up_proj")
                    assert 'weight' in name
                    up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
                    loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
                elif "ff_out" in name:
                    if "layers" in name:
                        name = name.replace("ff_out", "mlp.down_proj")
                    else:
                        # lm head
                        name = name.replace("model.transformer.ff_out",
                                            "lm_head")
                else:
                    for (param_name, weight_name) in params_mapping:
                        if param_name in name:
                            name = name.replace(param_name, weight_name)
                            break

            try:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
            except KeyError:
                raise ValueError(f"Unexpected weight: {name}") from None

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        gate_up_proj_weight = torch.cat(
            [projector_weight["gate_proj"], projector_weight["up_proj"]],
            dim=0)
        name = "vision_backbone.image_projector.gate_up_proj.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, gate_up_proj_weight)

        down_proj_weight = projector_weight["down_proj"]
        name = "vision_backbone.image_projector.down_proj.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, down_proj_weight)

        embedding_weight = torch.cat(
            [embedding_weight["embedding"], embedding_weight["new_embedding"]],
            dim=0)
        name = "model.embed_tokens.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, embedding_weight)