minicpmv.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
  24. import math
  25. import re
  26. from array import array
  27. from functools import partial
  28. from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
  29. TypedDict, Union)
  30. import numpy as np
  31. import torch
  32. import torch.nn.functional as F
  33. import torch.types
  34. from PIL import Image
  35. from torch import nn
  36. from torch.nn.init import trunc_normal_
  37. from transformers.configuration_utils import PretrainedConfig
  38. from aphrodite.attention import AttentionMetadata
  39. from aphrodite.common.config import CacheConfig, MultiModalConfig
  40. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  41. SequenceData)
  42. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  43. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  44. from aphrodite.modeling.layers.linear import ReplicatedLinear
  45. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  46. from aphrodite.modeling.layers.sampler import Sampler
  47. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  48. from aphrodite.modeling.model_loader.utils import set_default_torch_dtype
  49. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  50. from aphrodite.modeling.models.interfaces import SupportsMultiModal
  51. from aphrodite.modeling.models.llama import LlamaModel
  52. from aphrodite.modeling.models.minicpm import MiniCPMModel
  53. from aphrodite.modeling.models.qwen2 import Qwen2Model
  54. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  55. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  56. from aphrodite.multimodal.image import (cached_get_image_processor,
  57. cached_get_tokenizer)
  58. from aphrodite.quantization.base_config import QuantizationConfig
  59. from .idefics2_vision_model import Idefics2VisionTransformer
  60. _KEYS_TO_MODIFY_MAPPING = {
  61. "llm.lm_head": "lm_head",
  62. "llm.model": "llm",
  63. }
  64. class MiniCPMVImagePixelInputs(TypedDict):
  65. pixel_values: List[torch.Tensor]
  66. """
  67. Shape: `(batch_size * num_images, num_channels, height, width)`
  68. Note that the image size may vary, so we pass it as a list
  69. instead of a batched tensor.
  70. """
  71. image_bounds: torch.Tensor
  72. """
  73. Shape: `(batch_size * num_images, 2)`
  74. This should be in `(start, stop)` format.
  75. """
  76. tgt_sizes: torch.Tensor
  77. """
  78. Shape: `(batch_size * num_images, 2)`
  79. This should be in `(height, width)` format.
  80. """
  81. MiniCPMVImageInputs = MiniCPMVImagePixelInputs
  82. DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
  83. def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
  84. # abs_pos: L, C
  85. # tgt_size: (H, W)
  86. # return: M, C
  87. src_size = int(math.sqrt(abs_pos.size(0)))
  88. # tgt_size = int(math.sqrt(tgt_size))
  89. dtype = abs_pos.dtype
  90. return (F.interpolate(
  91. abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
  92. size=(tgt_size[0], tgt_size[1]),
  93. mode="bicubic",
  94. align_corners=False,
  95. ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
  96. # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
  97. def get_2d_sincos_pos_embed(
  98. embed_dim: int,
  99. grid_size: Union[int, Tuple[int, int]],
  100. cls_token: bool = False,
  101. version: Tuple[int, int] = (2, 0),
  102. ):
  103. """
  104. grid_size: int of the grid height and width
  105. return:
  106. pos_embed: [grid_size*grid_size, embed_dim] or
  107. [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
  108. """
  109. if isinstance(grid_size, int):
  110. grid_h_size, grid_w_size = grid_size, grid_size
  111. else:
  112. grid_h_size, grid_w_size = grid_size[0], grid_size[1]
  113. grid_h = np.arange(grid_h_size, dtype=np.float32)
  114. grid_w = np.arange(grid_w_size, dtype=np.float32)
  115. grid = np.meshgrid(grid_w, grid_h) # here w goes first
  116. grid = np.stack(grid, axis=0)
  117. if version == (2, 0):
  118. grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
  119. pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
  120. if cls_token:
  121. pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
  122. axis=0)
  123. else:
  124. pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
  125. return pos_embed
  126. def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
  127. grid: np.ndarray,
  128. version: Tuple[int, int] = (2, 0)):
  129. assert embed_dim % 2 == 0
  130. # use half of dimensions to encode grid_h
  131. emb_h = get_1d_sincos_pos_embed_from_grid(
  132. embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
  133. emb_w = get_1d_sincos_pos_embed_from_grid(
  134. embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
  135. if version == (2, 0):
  136. emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
  137. else:
  138. emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
  139. return emb
  140. def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
  141. pos: np.ndarray,
  142. version: Tuple[int, int] = (2, 0)):
  143. """
  144. embed_dim: output dimension for each position
  145. pos: a list of positions to be encoded: size (M,) / (H, W)
  146. out: (M, D) / (H, W, D)
  147. """
  148. assert embed_dim % 2 == 0
  149. omega = np.arange(embed_dim // 2, dtype=np.float32)
  150. omega /= embed_dim / 2.0
  151. omega = 1.0 / 10000**omega # (D/2,)
  152. if version == (2, 0):
  153. pos = pos.reshape(-1) # (M,)
  154. out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
  155. emb_sin = np.sin(out) # (M, D/2)
  156. emb_cos = np.cos(out) # (M, D/2)
  157. emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
  158. else:
  159. out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
  160. emb_sin = np.sin(out) # (H, W, D/2)
  161. emb_cos = np.cos(out) # (H, W, D/2)
  162. emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
  163. return emb
  164. class BaseResampler(nn.Module):
  165. """
  166. A 2D perceiver-resampler network with one cross attention layers by
  167. (grid_size**2) learnable queries and 2d sincos pos_emb
  168. Outputs:
  169. A tensor with the shape of (grid_size**2, embed_dim)
  170. """
  171. def __init__(
  172. self,
  173. num_queries: int,
  174. embed_dim: int,
  175. num_heads: int,
  176. kv_dim: Optional[int] = None,
  177. norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
  178. ) -> None:
  179. super().__init__()
  180. self.num_queries = num_queries
  181. self.embed_dim = embed_dim
  182. self.num_heads = num_heads
  183. self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
  184. trunc_normal_(self.query, std=0.02)
  185. if kv_dim is not None and kv_dim != embed_dim:
  186. self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
  187. else:
  188. # Maintain the same return value with ReplicatedLinear.forward
  189. self.kv_proj = lambda *args, **kwargs: (
  190. nn.Identity()(*args, **kwargs),
  191. None,
  192. )
  193. self.attn = nn.MultiheadAttention(embed_dim, num_heads)
  194. self.ln_q = norm_layer(embed_dim)
  195. self.ln_kv = norm_layer(embed_dim)
  196. self.ln_post = norm_layer(embed_dim)
  197. self.proj = nn.Parameter(
  198. (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
  199. def _init_weights(self, m: nn.Module) -> None:
  200. if isinstance(m, nn.Linear):
  201. trunc_normal_(m.weight, std=0.02)
  202. if isinstance(m, nn.Linear) and m.bias is not None:
  203. nn.init.constant_(m.bias, 0)
  204. elif isinstance(m, nn.LayerNorm):
  205. nn.init.constant_(m.bias, 0)
  206. nn.init.constant_(m.weight, 1.0)
  207. def _repeat(self, query, N: int):
  208. return query.unsqueeze(1).repeat(1, N, 1)
  209. class Resampler2(BaseResampler):
  210. def __init__(
  211. self,
  212. grid_size: int,
  213. embed_dim: int,
  214. num_heads: int,
  215. kv_dim: Optional[int] = None,
  216. norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
  217. adaptive: bool = False,
  218. ) -> None:
  219. super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
  220. norm_layer)
  221. self.adaptive = adaptive
  222. pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
  223. grid_size,
  224. version=(2, 0))
  225. self.pos_embed = nn.Parameter(
  226. torch.from_numpy(pos_embed_arr).float()).requires_grad_(False)
  227. self.apply(self._init_weights)
  228. def forward(
  229. self,
  230. x: torch.Tensor,
  231. tgt_sizes: torch.Tensor,
  232. attn_mask: Optional[torch.Tensor] = None,
  233. ):
  234. if self.adaptive:
  235. pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
  236. tgt_sizes,
  237. version=(2, 0))
  238. pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
  239. dtype=x.dtype)
  240. else:
  241. pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
  242. x, _ = self.kv_proj(x)
  243. x = self.ln_kv(x).permute(1, 0, 2)
  244. N = x.shape[1]
  245. q = self.ln_q(self.query)
  246. out = self.attn(
  247. self._repeat(q, N) + self.pos_embed.unsqueeze(1),
  248. x + pos_embed.unsqueeze(1),
  249. x,
  250. attn_mask=attn_mask,
  251. )[0]
  252. x = out.permute(1, 0, 2)
  253. x = self.ln_post(x)
  254. x = x @ self.proj
  255. return x
  256. class Resampler2_5(BaseResampler):
  257. def __init__(
  258. self,
  259. num_queries: int,
  260. embed_dim: int,
  261. num_heads: int,
  262. kv_dim: Optional[int] = None,
  263. norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
  264. max_size: Tuple[int, int] = (70, 70),
  265. ) -> None:
  266. super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
  267. self.max_size = max_size
  268. self._set_2d_pos_cache(self.max_size)
  269. self.apply(self._init_weights)
  270. def _set_2d_pos_cache(self,
  271. max_size: Tuple[int, int],
  272. device: torch.types.Device = "cpu") -> None:
  273. pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
  274. max_size,
  275. version=(2, 5))
  276. pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
  277. self.register_buffer("pos_embed", pos_embed, persistent=False)
  278. def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
  279. device: torch.types.Device) -> None:
  280. max_h = tgt_sizes[:, 0].max().item()
  281. max_w = tgt_sizes[:, 1].max().item()
  282. assert isinstance(max_h, int) and isinstance(max_w, int)
  283. if max_h > self.max_size[0] or max_w > self.max_size[1]:
  284. self.max_size = (
  285. max(max_h, self.max_size[0]),
  286. max(max_w, self.max_size[1]),
  287. )
  288. self._set_2d_pos_cache(self.max_size, device)
  289. def forward(self, x: torch.Tensor,
  290. tgt_sizes: torch.Tensor) -> torch.Tensor:
  291. assert x.shape[0] == tgt_sizes.shape[0]
  292. bs = x.shape[0]
  293. device = x.device
  294. dtype = x.dtype
  295. patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
  296. self._adjust_pos_cache(tgt_sizes, device=device)
  297. max_patch_len = patch_len.max().item()
  298. assert isinstance(max_patch_len, int)
  299. key_padding_mask = torch.zeros((bs, max_patch_len),
  300. dtype=torch.bool,
  301. device=device)
  302. pos_embed = []
  303. for i in range(bs):
  304. tgt_h, tgt_w = tgt_sizes[i].tolist()
  305. pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
  306. (tgt_h * tgt_w, -1)).to(dtype)) # patches * D
  307. key_padding_mask[i, patch_len[i]:] = True
  308. pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed,
  309. batch_first=True,
  310. padding_value=0.0).permute(
  311. 1, 0,
  312. 2) # BLD => L * B * D
  313. x, _ = self.kv_proj(x) # B * L * D
  314. x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
  315. q = self.ln_q(self.query) # Q * D
  316. out = self.attn(
  317. self._repeat(q, bs), # Q * B * D
  318. x + pos_embed, # L * B * D + L * B * D
  319. x,
  320. key_padding_mask=key_padding_mask,
  321. )[0]
  322. # out: Q * B * D
  323. x = out.permute(1, 0, 2) # B * Q * D
  324. x = self.ln_post(x)
  325. x = x @ self.proj
  326. return x
  327. def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
  328. version_float = getattr(config, "version", None)
  329. # The old configs do not include version number
  330. # TODO: Remove this after the HF repos are updated
  331. if version_float is None:
  332. if config.hidden_size == 2304 and config.query_num == 64:
  333. return (2, 0)
  334. return (2, 5)
  335. version_str = str(version_float)
  336. return tuple(int(x) for x in version_str.split("."))
  337. def get_max_minicpmv_image_tokens(ctx: InputContext):
  338. hf_config = ctx.get_hf_config(PretrainedConfig)
  339. return getattr(hf_config, "query_num", 64)
  340. def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
  341. token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len
  342. return SequenceData(token_ids)
  343. def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
  344. width = height = hf_config.image_size
  345. image = Image.new("RGB", (width, height), color=0)
  346. return {"image": image if num_images == 1 else [image] * num_images}
  347. def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
  348. mm_counts: Mapping[str, int]):
  349. hf_config = ctx.get_hf_config()
  350. num_images = mm_counts["image"]
  351. seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
  352. mm_data = dummy_image_for_minicpmv(hf_config, num_images)
  353. return seq_data, mm_data
  354. def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
  355. multi_modal_data = llm_inputs.get("multi_modal_data")
  356. if multi_modal_data is None or "image" not in multi_modal_data:
  357. return llm_inputs
  358. model_config = ctx.model_config
  359. version = get_version_by_config(model_config.hf_config)
  360. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  361. trust_remote_code=True)
  362. image_processor = cached_get_image_processor(model_config.tokenizer)
  363. def get_placeholder(image_size: Tuple[int, int], num_image: int):
  364. if version == (2, 0) or version == (2, 5):
  365. return image_processor. \
  366. get_slice_image_placeholder(image_size)
  367. return image_processor. \
  368. get_slice_image_placeholder(image_size, num_image)
  369. prompt = llm_inputs.get("prompt")
  370. if prompt is None:
  371. token_ids = llm_inputs.get("prompt_token_ids")
  372. prompt = tokenizer.decode(token_ids)
  373. pattern = "(<image>./</image>)"
  374. images = multi_modal_data["image"]
  375. if isinstance(images, Image.Image):
  376. images = [images]
  377. image_tags = re.findall(pattern, prompt)
  378. if len(image_tags) == 0:
  379. new_token_ids = token_ids
  380. new_prompt = prompt
  381. else:
  382. text_chunks = prompt.split(pattern)
  383. new_prompt_chunks: List[str] = []
  384. for i in range(len(images)):
  385. new_prompt_chunks += [
  386. text_chunks[i],
  387. get_placeholder(images[i].size, i)
  388. ]
  389. new_prompt_chunks.append(text_chunks[-1])
  390. new_prompt = "".join(new_prompt_chunks)
  391. new_token_ids = tokenizer.encode(new_prompt)
  392. llm_inputs = LLMInputs(
  393. prompt_token_ids=new_token_ids,
  394. prompt=new_prompt,
  395. multi_modal_data=multi_modal_data,
  396. )
  397. return llm_inputs
  398. class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
  399. """
  400. The abstract class of MiniCPMV can only be inherited, but cannot be
  401. instantiated.
  402. """
  403. def __init__(
  404. self,
  405. config: PretrainedConfig,
  406. multimodal_config: MultiModalConfig,
  407. cache_config: Optional[CacheConfig] = None,
  408. quant_config: Optional[QuantizationConfig] = None,
  409. ):
  410. super().__init__()
  411. self.config = config
  412. self.multimodal_config = multimodal_config
  413. self.version = get_version_by_config(self.config)
  414. self.llm = self.init_llm(config, cache_config, quant_config)
  415. self.vpm = self.init_vision_module()
  416. param_dtype = torch.get_default_dtype()
  417. self.vpm.to(dtype=param_dtype)
  418. self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
  419. self.vpm.embeddings.embed_dim)
  420. self.embed_dim = self.config.hidden_size
  421. self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
  422. self.resampler.to(device="cuda", dtype=param_dtype)
  423. self.lm_head = ParallelLMHead(config.vocab_size,
  424. config.hidden_size,
  425. quant_config=quant_config)
  426. self.logits_processor = LogitsProcessor(config.vocab_size)
  427. self.sampler = Sampler()
  428. def get_embedding(
  429. self,
  430. input_ids: torch.Tensor,
  431. image_inputs: Optional[MiniCPMVImageInputs],
  432. ) -> Tuple[torch.Tensor, torch.Tensor]:
  433. vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
  434. if hasattr(self.config, "scale_emb"):
  435. vlm_embedding *= self.config.scale_emb
  436. if image_inputs is None: # No image
  437. vision_hidden_states = torch.tensor([], device=input_ids.device)
  438. else:
  439. vision_hidden_states = self.get_vision_hidden_states(image_inputs)
  440. # See NOTE in _parse_and_validate_inputs
  441. image_bounds = image_inputs["image_bounds"]
  442. if len(image_bounds) > 0:
  443. image_indices = torch.stack([
  444. torch.arange(start, end, dtype=torch.long)
  445. for start, end in image_bounds.tolist()
  446. ]).to(vlm_embedding.device)
  447. vlm_embedding.scatter_(
  448. 0,
  449. image_indices.view(-1, 1).repeat(1,
  450. vlm_embedding.shape[-1]),
  451. vision_hidden_states.view(-1,
  452. vision_hidden_states.shape[-1]),
  453. )
  454. return vlm_embedding, vision_hidden_states
  455. def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
  456. tokenizer = cached_get_tokenizer(self.config._name_or_path,
  457. trust_remote_code=True)
  458. start_cond = input_ids == tokenizer.im_start_id
  459. end_cond = input_ids == tokenizer.im_end_id
  460. if hasattr(tokenizer, "slice_start_id"):
  461. start_cond |= (input_ids == tokenizer.slice_start_id)
  462. end_cond |= (input_ids == tokenizer.slice_end_id)
  463. image_start_tokens, = torch.where(start_cond)
  464. image_start_tokens += 1
  465. image_end_tokens, = torch.where(end_cond)
  466. valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
  467. if valid_image_nums == 0:
  468. return torch.zeros((0, 2), device=input_ids.device)
  469. return torch.hstack([
  470. image_start_tokens[:valid_image_nums].unsqueeze(-1),
  471. image_end_tokens[:valid_image_nums].unsqueeze(-1),
  472. ])
  473. def _parse_and_validate_inputs(
  474. self,
  475. input_ids: torch.Tensor,
  476. **kwargs: object,
  477. ) -> Optional[MiniCPMVImageInputs]:
  478. pixel_values = kwargs.pop("pixel_values", [])
  479. tgt_sizes = kwargs.pop("tgt_sizes", [])
  480. if not isinstance(pixel_values, (torch.Tensor, list)):
  481. raise ValueError("Incorrect type of pixel values. "
  482. f"Got type: {type(pixel_values)}")
  483. if not isinstance(tgt_sizes, (torch.Tensor, list)):
  484. raise ValueError("Incorrect type of target sizes. "
  485. f"Got type: {type(tgt_sizes)}")
  486. if len(pixel_values) != len(tgt_sizes):
  487. raise ValueError("Inconsistent batch lengths, found: "
  488. f"{len(pixel_values)} vs. {len(tgt_sizes)}")
  489. pixel_values_flat: List[torch.Tensor] = []
  490. tgt_sizes_flat: List[torch.Tensor] = []
  491. for b in range(len(pixel_values)):
  492. pixel_values_flat += pixel_values[b]
  493. tgt_sizes_flat += tgt_sizes[b]
  494. # NOTE: Input IDs does not contain image tokens during memory profiling,
  495. # so we allow it to be empty
  496. if len(pixel_values_flat) != len(tgt_sizes_flat):
  497. raise ValueError("Inconsistent flattened lengths, found: "
  498. f"{len(pixel_values_flat)} vs. "
  499. f"{len(tgt_sizes_flat)}")
  500. if len(pixel_values_flat) == 0:
  501. return None
  502. return MiniCPMVImageInputs(
  503. image_bounds=self._get_image_bounds(input_ids),
  504. pixel_values=pixel_values_flat,
  505. tgt_sizes=torch.stack(tgt_sizes_flat),
  506. )
  507. def forward(
  508. self,
  509. input_ids: torch.Tensor,
  510. positions: torch.Tensor,
  511. kv_caches: List[torch.Tensor],
  512. attn_metadata: AttentionMetadata,
  513. intermediate_tensors: Optional[IntermediateTensors] = None,
  514. **kwargs: Any,
  515. ) -> torch.Tensor:
  516. image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
  517. vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
  518. output = self.llm(
  519. input_ids=None,
  520. positions=positions,
  521. kv_caches=kv_caches,
  522. attn_metadata=attn_metadata,
  523. intermediate_tensors=intermediate_tensors,
  524. inputs_embeds=vlm_embeddings,
  525. )
  526. return output
  527. def compute_logits(
  528. self,
  529. hidden_states: torch.Tensor,
  530. sampling_metadata: SamplingMetadata,
  531. ) -> Optional[torch.Tensor]:
  532. logits = self.logits_processor(self.lm_head, hidden_states,
  533. sampling_metadata)
  534. return logits
  535. def sample(
  536. self,
  537. logits: torch.Tensor,
  538. sampling_metadata: SamplingMetadata,
  539. ) -> Optional[SamplerOutput]:
  540. next_tokens = self.sampler(logits, sampling_metadata)
  541. return next_tokens
  542. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  543. stacked_params_mapping = [
  544. # (param_name, shard_name, shard_id)
  545. ("qkv_proj", "q_proj", "q"),
  546. ("qkv_proj", "k_proj", "k"),
  547. ("qkv_proj", "v_proj", "v"),
  548. ("gate_up_proj", "gate_proj", 0),
  549. ("gate_up_proj", "up_proj", 1),
  550. ]
  551. params_dict = dict(self.named_parameters())
  552. for name, loaded_weight in weights:
  553. for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
  554. if key_to_modify in name:
  555. name = name.replace(key_to_modify, new_key)
  556. if "rotary_emb.inv_freq" in name:
  557. continue
  558. if ("rotary_emb.cos_cached" in name
  559. or "rotary_emb.sin_cached" in name):
  560. # Models trained using ColossalAI may include these tensors in
  561. # the checkpoint. Skip them.
  562. continue
  563. use_default_weight_loading = False
  564. if self.is_default_weight_loading(name):
  565. use_default_weight_loading = True
  566. else:
  567. for param_name, weight_name, shard_id in stacked_params_mapping:
  568. if weight_name not in name:
  569. continue
  570. param = params_dict[name.replace(weight_name, param_name)]
  571. weight_loader = param.weight_loader
  572. weight_loader(param, loaded_weight, shard_id)
  573. break
  574. else:
  575. use_default_weight_loading = True
  576. if use_default_weight_loading:
  577. param = params_dict[name]
  578. weight_loader = getattr(param, "weight_loader",
  579. default_weight_loader)
  580. weight_loader(param, loaded_weight)
  581. def init_llm(
  582. self,
  583. config: PretrainedConfig,
  584. cache_config: Optional[CacheConfig] = None,
  585. quant_config: Optional[QuantizationConfig] = None,
  586. ) -> nn.Module:
  587. raise NotImplementedError
  588. def init_vision_module(self) -> nn.Module:
  589. raise NotImplementedError
  590. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  591. raise NotImplementedError
  592. def get_vision_embedding(
  593. self,
  594. pixel_values: List[torch.Tensor],
  595. patch_attn_mask: Optional[torch.Tensor] = None,
  596. tgt_sizes: Optional[torch.Tensor] = None,
  597. ) -> torch.Tensor:
  598. raise NotImplementedError
  599. def get_vision_hidden_states(self,
  600. data: MiniCPMVImageInputs) -> torch.Tensor:
  601. raise NotImplementedError
  602. def is_default_weight_loading(self, name: str) -> bool:
  603. raise NotImplementedError
  604. class MiniCPMV2(MiniCPMVBaseModel):
  605. def __init__(
  606. self,
  607. config: PretrainedConfig,
  608. multimodal_config: MultiModalConfig,
  609. cache_config: Optional[CacheConfig] = None,
  610. quant_config: Optional[QuantizationConfig] = None,
  611. ):
  612. super().__init__(config, multimodal_config, cache_config, quant_config)
  613. assert self.version == (2, 0)
  614. def init_llm(
  615. self,
  616. config: PretrainedConfig,
  617. cache_config: Optional[CacheConfig] = None,
  618. quant_config: Optional[QuantizationConfig] = None,
  619. ) -> nn.Module:
  620. return MiniCPMModel(config,
  621. cache_config=cache_config,
  622. quant_config=quant_config)
  623. def init_vision_module(self) -> nn.Module:
  624. # TODO :refactor this vision model
  625. try:
  626. import timm
  627. except ImportError:
  628. raise ImportError("Please install timm==0.9.10") from ImportError
  629. with set_default_torch_dtype(torch.float16):
  630. model = timm.create_model(
  631. "vit_so400m_patch14_siglip_384.webli",
  632. pretrained=False,
  633. num_classes=0,
  634. dynamic_img_size=True,
  635. dynamic_img_pad=True,
  636. )
  637. if (isinstance(model, timm.models.VisionTransformer)
  638. and model.attn_pool is not None):
  639. model.attn_pool = torch.nn.Identity()
  640. if self.config.drop_vision_last_layer:
  641. model.blocks = model.blocks[:-1]
  642. return model
  643. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  644. with set_default_torch_dtype(torch.float16):
  645. resampler = Resampler2(
  646. embed_dim=embed_dim,
  647. num_heads=embed_dim // 128,
  648. grid_size=int(math.sqrt(self.config.query_num)),
  649. kv_dim=vision_dim,
  650. adaptive=True,
  651. )
  652. return resampler
  653. def get_vision_embedding(
  654. self,
  655. pixel_values: List[torch.Tensor],
  656. patch_attn_mask: Optional[torch.Tensor] = None,
  657. tgt_sizes: Optional[torch.Tensor] = None,
  658. ) -> torch.Tensor:
  659. res = []
  660. dtype = self.vpm.pos_embed.data.dtype
  661. for pixel_value in pixel_values:
  662. H, W = pixel_value[0].shape[-2:]
  663. tgt_size = (
  664. math.ceil(H / self.vpm.patch_embed.patch_size[0]),
  665. math.ceil(W / self.vpm.patch_embed.patch_size[0]),
  666. )
  667. vision_embedding = self.vpm.forward_features(
  668. pixel_value.unsqueeze(0).type(dtype))
  669. if (hasattr(self.vpm, "num_prefix_tokens")
  670. and self.vpm.num_prefix_tokens > 0):
  671. vision_embedding = vision_embedding[:, self.vpm.
  672. num_prefix_tokens:]
  673. res.append(self.resampler(vision_embedding, tgt_size))
  674. return torch.vstack(res)
  675. def get_vision_hidden_states(self,
  676. data: MiniCPMVImageInputs) -> torch.Tensor:
  677. pixel_values = data["pixel_values"]
  678. return self.get_vision_embedding(pixel_values)
  679. def is_default_weight_loading(self, name: str) -> bool:
  680. return "resampler" in name or "vpm" in name
  681. class MiniCPMV2_5(MiniCPMVBaseModel):
  682. def __init__(
  683. self,
  684. config: PretrainedConfig,
  685. multimodal_config: MultiModalConfig,
  686. cache_config: Optional[CacheConfig] = None,
  687. quant_config: Optional[QuantizationConfig] = None,
  688. ):
  689. super().__init__(config, multimodal_config, cache_config, quant_config)
  690. assert self.version == (2, 5)
  691. def init_llm(
  692. self,
  693. config: PretrainedConfig,
  694. cache_config: Optional[CacheConfig] = None,
  695. quant_config: Optional[QuantizationConfig] = None,
  696. ) -> nn.Module:
  697. return LlamaModel(config,
  698. cache_config=cache_config,
  699. quant_config=quant_config)
  700. def init_vision_module(self) -> nn.Module:
  701. model = Idefics2VisionTransformer(self.config.vision_config)
  702. if self.config.drop_vision_last_layer:
  703. model.encoder.layers = model.encoder.layers[:-1]
  704. return model
  705. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  706. with set_default_torch_dtype(torch.float16):
  707. resampler = Resampler2_5(
  708. num_queries=self.config.query_num,
  709. embed_dim=embed_dim,
  710. num_heads=embed_dim // 128,
  711. kv_dim=vision_dim,
  712. )
  713. return resampler
  714. def get_vision_embedding(
  715. self,
  716. pixel_values: List[torch.Tensor],
  717. patch_attn_mask: Optional[torch.Tensor] = None,
  718. tgt_sizes: Optional[torch.Tensor] = None,
  719. ) -> torch.Tensor:
  720. vision_embedding = self.vpm(pixel_values,
  721. patch_attention_mask=patch_attn_mask)
  722. vision_embedding = self.resampler(vision_embedding, tgt_sizes)
  723. return vision_embedding
  724. def get_vision_hidden_states(self,
  725. data: MiniCPMVImageInputs) -> torch.Tensor:
  726. pixel_values = data["pixel_values"]
  727. tgt_sizes = data["tgt_sizes"]
  728. device = self.vpm.embeddings.position_embedding.weight.device
  729. dtype = self.vpm.embeddings.position_embedding.weight.dtype
  730. all_pixel_values_lst = [
  731. i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
  732. ]
  733. max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
  734. assert isinstance(max_patches, int)
  735. all_pixel_values = torch.nn.utils.rnn.pad_sequence(
  736. all_pixel_values_lst, batch_first=True, padding_value=0.0)
  737. B, L, _ = all_pixel_values.shape
  738. all_pixel_values = all_pixel_values.permute(0, 2,
  739. 1).reshape(B, 3, -1, L)
  740. patch_attn_mask = torch.zeros((B, 1, max_patches),
  741. dtype=torch.bool,
  742. device=device)
  743. for i in range(B):
  744. patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
  745. return self.get_vision_embedding(all_pixel_values.type(dtype),
  746. patch_attn_mask, tgt_sizes)
  747. def is_default_weight_loading(self, name: str) -> bool:
  748. return "resampler" in name
# NOTE: Currently, information about this model is unavailable. We are
# temporarily using `MiniCPMVQwen2` as its name. The name may need
# to be modified in the future.
  752. class MiniCPMVQwen2(MiniCPMVBaseModel):
  753. def __init__(
  754. self,
  755. config: PretrainedConfig,
  756. multimodal_config: MultiModalConfig,
  757. cache_config: Optional[CacheConfig] = None,
  758. quant_config: Optional[QuantizationConfig] = None,
  759. ):
  760. super().__init__(config, multimodal_config, cache_config, quant_config)
  761. def init_llm(
  762. self,
  763. config: PretrainedConfig,
  764. cache_config: Optional[CacheConfig] = None,
  765. quant_config: Optional[QuantizationConfig] = None,
  766. ) -> nn.Module:
  767. return Qwen2Model(config,
  768. cache_config=cache_config,
  769. quant_config=quant_config)
  770. def init_vision_module(self) -> nn.Module:
  771. # A custom version of SiglipVisionTransformer, won't work with TP
  772. from aphrodite.modeling.models.na_vit import SiglipVisionTransformer
  773. if self.config._attn_implementation == "flash_attention_2":
  774. self.config.vision_config._attn_implementation = "flash_attention_2"
  775. else:
  776. # not support sdpa
  777. self.config.vision_config._attn_implementation = "eager"
  778. model = SiglipVisionTransformer(self.config.vision_config)
  779. if self.config.drop_vision_last_layer:
  780. model.encoder.layers = model.encoder.layers[:-1]
  781. return model
  782. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  783. with set_default_torch_dtype(torch.float16):
  784. resampler = Resampler2_5(
  785. num_queries=self.config.query_num,
  786. embed_dim=embed_dim,
  787. num_heads=embed_dim // 128,
  788. kv_dim=vision_dim,
  789. )
  790. return resampler
  791. def get_vision_embedding(
  792. self,
  793. pixel_values: List[torch.Tensor],
  794. patch_attn_mask: Optional[torch.Tensor] = None,
  795. tgt_sizes: Optional[torch.Tensor] = None,
  796. ) -> torch.Tensor:
  797. vision_embedding = self.vpm(
  798. pixel_values,
  799. patch_attention_mask=patch_attn_mask,
  800. tgt_sizes=tgt_sizes,
  801. ).last_hidden_state
  802. return vision_embedding
  803. def get_vision_hidden_states(self,
  804. data: MiniCPMVImageInputs) -> torch.Tensor:
  805. pixel_values = data["pixel_values"]
  806. tgt_sizes = data["tgt_sizes"]
  807. device = self.vpm.embeddings.position_embedding.weight.device
  808. dtype = self.vpm.embeddings.position_embedding.weight.dtype
  809. all_pixel_values_lst = [
  810. i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
  811. ]
  812. max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
  813. assert isinstance(max_patches, int)
  814. all_pixel_values = torch.nn.utils.rnn.pad_sequence(
  815. all_pixel_values_lst, batch_first=True, padding_value=0.0)
  816. B, L, _ = all_pixel_values.shape
  817. all_pixel_values = all_pixel_values.permute(0, 2,
  818. 1).reshape(B, 3, -1, L)
  819. patch_attn_mask = torch.zeros((B, 1, max_patches),
  820. dtype=torch.bool,
  821. device=device)
  822. for i in range(B):
  823. patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
  824. vision_embedding = self.vpm(
  825. all_pixel_values.type(dtype),
  826. patch_attention_mask=patch_attn_mask,
  827. tgt_sizes=tgt_sizes,
  828. ).last_hidden_state
  829. return self.resampler(vision_embedding, tgt_sizes)
  830. def is_default_weight_loading(self, name: str) -> bool:
  831. return "resampler" in name or "vpm" in name
  832. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  833. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
  834. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
  835. @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
  836. class MiniCPMV(MiniCPMVBaseModel):
  837. """
  838. Different versions of MiniCPMV use different visual encoders and LLMs,
  839. which is not conducive to the current integration logic of LoRA and
  840. bitsandbytes in aphrodite. Therefore, it is necessary to separate them.
  841. """
  842. def __new__(
  843. cls,
  844. config: PretrainedConfig,
  845. multimodal_config: MultiModalConfig,
  846. cache_config: Optional[CacheConfig] = None,
  847. quant_config: Optional[QuantizationConfig] = None,
  848. ):
  849. if not hasattr(config, "version"):
  850. if config.hidden_size == 2304 and config.query_num == 64:
  851. version = (2, 0)
  852. else:
  853. version = (2, 5)
  854. else:
  855. version = str(config.version).split(".")
  856. version = tuple([int(x) for x in version])
  857. # Dispatch class based on version
  858. if version == (2, 0):
  859. instance_class = MiniCPMV2
  860. elif version == (2, 5):
  861. instance_class = MiniCPMV2_5
  862. else:
  863. instance_class = MiniCPMVQwen2
  864. return instance_class(config, multimodal_config, cache_config,
  865. quant_config)