# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict,
                    Union)

import numpy as np
import torch
import torch.nn.functional as F
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers.configuration_utils import PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.linear import ReplicatedLinear
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.utils import set_default_torch_dtype
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsVision
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.models.minicpm import MiniCPMModel
from aphrodite.modeling.models.qwen2 import Qwen2Model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import (cached_get_image_processor,
                                        cached_get_tokenizer)
from aphrodite.quantization.base_config import QuantizationConfig

from .idefics2_vision_model import Idefics2VisionTransformer

_KEYS_TO_MODIFY_MAPPING = {
    "llm.lm_head": "lm_head",
    "llm.model": "llm",
}


class MiniCPMVImagePixelInputs(TypedDict):
    pixel_values: List[torch.Tensor]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that the image size may vary, so we pass it as a list
    instead of a batched tensor.
    """

    image_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(start, stop)` format.
    """

    tgt_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


MiniCPMVImageInputs = MiniCPMVImagePixelInputs

DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
    # abs_pos: L, C
    # tgt_size: (H, W)
    # return: M, C
    src_size = int(math.sqrt(abs_pos.size(0)))
    # tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    return (F.interpolate(
        abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
        size=(tgt_size[0], tgt_size[1]),
        mode="bicubic",
        align_corners=False,
    ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(
    embed_dim: int,
    grid_size: Union[int, Tuple[int, int]],
    cls_token: bool = False,
    version: Tuple[int, int] = (2, 0),
):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or
               [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size[0], grid_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    if version == (2, 0):
        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
        if cls_token:
            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
                                       axis=0)
    else:
        pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
    return pos_embed
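

# Shape sketch for get_2d_sincos_pos_embed (illustrative only):
#   get_2d_sincos_pos_embed(8, (2, 3), version=(2, 0)).shape == (6, 8)
#   get_2d_sincos_pos_embed(8, (2, 3), version=(2, 5)).shape == (2, 3, 8)
# i.e. version (2, 0) flattens the grid into one row per patch, while
# version (2, 5) keeps the (H, W) grid layout.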


def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
                                      grid: np.ndarray,
                                      version: Tuple[int, int] = (2, 0)):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0], version)  # (H*W, D/2) or (H, W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1], version)  # (H*W, D/2) or (H, W, D/2)

    if version == (2, 0):
        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    else:
        emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
                                      pos: np.ndarray,
                                      version: Tuple[int, int] = (2, 0)):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) / (H, W)
    out: (M, D) / (H, W, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    if version == (2, 0):
        pos = pos.reshape(-1)  # (M,)
        out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
        emb_sin = np.sin(out)  # (M, D/2)
        emb_cos = np.cos(out)  # (M, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    else:
        out = np.einsum("hw,d->hwd", pos, omega)  # (H, W, D/2), outer product
        emb_sin = np.sin(out)  # (H, W, D/2)
        emb_cos = np.cos(out)  # (H, W, D/2)
        emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (H, W, D)
    return emb
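

# The 1D embedding above is the standard transformer sinusoid: for each
# frequency index d in [0, embed_dim / 2), out[..., d] = pos / 10000 ** (d /
# (embed_dim / 2)), with sin(out) filling the first half of the output
# channels and cos(out) the second half.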


class BaseResampler(nn.Module):
    """
    A 2D perceiver-resampler network with a single cross-attention layer over
    (grid_size**2) learnable queries and 2D sincos position embeddings.

    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
    ) -> None:
        super().__init__()

        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=0.02)
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
        else:
            # Maintain the same return value with ReplicatedLinear.forward
            self.kv_proj = lambda *args, **kwargs: (
                nn.Identity()(*args, **kwargs),
                None,
            )
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)
        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter(
            (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)


class Resampler2(BaseResampler):

    def __init__(
        self,
        grid_size: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        adaptive: bool = False,
    ) -> None:
        super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
                         norm_layer)

        self.adaptive = adaptive
        pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
                                                grid_size,
                                                version=(2, 0))
        self.pos_embed = nn.Parameter(
            torch.from_numpy(pos_embed_arr).float()).requires_grad_(False)

        self.apply(self._init_weights)

    def forward(
        self,
        x: torch.Tensor,
        tgt_sizes: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        if self.adaptive:
            pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
                                                    tgt_sizes,
                                                    version=(2, 0))
            pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
                                                           dtype=x.dtype)
        else:
            pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)

        x, _ = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]
        q = self.ln_q(self.query)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask,
        )[0]
        x = out.permute(1, 0, 2)
        x = self.ln_post(x)
        x = x @ self.proj
        return x
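

# Shape flow for Resampler2.forward (descriptive only): the image features x
# enter as (B, L, kv_dim), are projected to embed_dim, and the grid_size**2
# learnable queries cross-attend over them, so the output is always
# (B, grid_size**2, embed_dim) regardless of the number of patches L.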


class Resampler2_5(BaseResampler):

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        max_size: Tuple[int, int] = (70, 70),
    ) -> None:
        super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)

        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)

        self.apply(self._init_weights)

    def _set_2d_pos_cache(self,
                          max_size: Tuple[int, int],
                          device: torch.types.Device = "cpu") -> None:
        pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
                                                max_size,
                                                version=(2, 5))
        pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
                          device: torch.types.Device) -> None:
        max_h = tgt_sizes[:, 0].max().item()
        max_w = tgt_sizes[:, 1].max().item()
        assert isinstance(max_h, int) and isinstance(max_w, int)

        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = (
                max(max_h, self.max_size[0]),
                max(max_w, self.max_size[1]),
            )
            self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor,
                tgt_sizes: torch.Tensor) -> torch.Tensor:
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        max_patch_len = patch_len.max().item()
        assert isinstance(max_patch_len, int)

        key_padding_mask = torch.zeros((bs, max_patch_len),
                                       dtype=torch.bool,
                                       device=device)

        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i].tolist()
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
                (tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            key_padding_mask[i, patch_len[i]:] = True
        pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed,
                                                    batch_first=True,
                                                    padding_value=0.0).permute(
                                                        1, 0,
                                                        2)  # BLD => L * B * D

        x, _ = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D + L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[0]
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x
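

# Resampler2_5 differs from Resampler2 in that it resamples a batch of images
# with different patch grids in one call: per-image 2D position embeddings are
# sliced from a cached table (grown on demand by _adjust_pos_cache), padded to
# the longest sequence, and padded positions are excluded through
# key_padding_mask.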


def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
    version_float = getattr(config, "version", None)

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    if version_float is None:
        if config.hidden_size == 2304 and config.query_num == 64:
            return (2, 0)
        return (2, 5)

    version_str = str(version_float)
    return tuple(int(x) for x in version_str.split("."))
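

# For example, `version = 2.5` in the HF config maps to the tuple (2, 5);
# configs that predate the version field fall back to the hidden-size /
# query-count check above.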


def get_max_minicpmv_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PretrainedConfig)
    return getattr(hf_config, "query_num", 64)


def dummy_seq_data_for_minicpmv(seq_len: int):
    token_ids = [0] * seq_len
    return SequenceData(token_ids)


def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
    width = height = hf_config.image_size
    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}


def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(PretrainedConfig)

    seq_data = dummy_seq_data_for_minicpmv(seq_len)
    mm_data = dummy_image_for_minicpmv(hf_config)

    return seq_data, mm_data
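

# The dummy sequence and image above are only used when the engine profiles
# memory for a worst-case multimodal request (see the NOTE in
# _parse_and_validate_inputs); the all-zero tokens and black image exercise a
# full forward pass but never produce user-visible output.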


def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    version = get_version_by_config(model_config.hf_config)
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    image_processor = cached_get_image_processor(model_config.tokenizer)

    def get_placeholder(image_size: Tuple[int, int], num_image: int):
        if version == (2, 0) or version == (2, 5):
            return image_processor. \
                get_slice_image_placeholder(image_size)
        return image_processor. \
            get_slice_image_placeholder(image_size, num_image)

    # Fetch the token IDs up front so the no-image-tag branch below always
    # has them available, even when a text prompt was provided.
    prompt = llm_inputs.get("prompt")
    token_ids = llm_inputs.get("prompt_token_ids")
    if prompt is None:
        prompt = tokenizer.decode(token_ids)

    pattern = "(<image>./</image>)"
    images = multi_modal_data["image"]
    if isinstance(images, Image.Image):
        images = [images]
    image_tags = re.findall(pattern, prompt)

    if len(image_tags) == 0:
        new_token_ids = token_ids
        new_prompt = prompt
    else:
        text_chunks = prompt.split(pattern)
        new_prompt_chunks: List[str] = []
        for i in range(len(images)):
            new_prompt_chunks += [
                text_chunks[i],
                get_placeholder(images[i].size, i)
            ]
        new_prompt_chunks.append(text_chunks[-1])
        new_prompt = "".join(new_prompt_chunks)
        new_token_ids = tokenizer.encode(new_prompt)

    llm_inputs = LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
    return llm_inputs
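

# Illustrative example (the placeholder text depends on the HF image
# processor): a prompt such as "(<image>./</image>)\nWhat is in this image?"
# has each image tag replaced by the processor's slice-image placeholder
# string, and the expanded prompt is then re-tokenized so that one token is
# reserved per vision embedding.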


class MiniCPMVBaseModel(nn.Module, SupportsVision):
    """
    Abstract base class for MiniCPM-V; it is meant to be subclassed and
    cannot be instantiated directly.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(config, cache_config, quant_config)
        self.vpm = self.init_vision_module()
        param_dtype = torch.get_default_dtype()
        self.vpm.to(dtype=param_dtype)
        self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                           self.vpm.embeddings.embed_dim)
        self.embed_dim = self.config.hidden_size
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
        self.resampler.to(device="cuda", dtype=param_dtype)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def get_embedding(
        self,
        input_ids: torch.Tensor,
        image_inputs: Optional[MiniCPMVImageInputs],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
        if hasattr(self.config, "scale_emb"):
            vlm_embedding *= self.config.scale_emb

        if image_inputs is None:  # No image
            vision_hidden_states = torch.tensor([], device=input_ids.device)
        else:
            vision_hidden_states = self.get_vision_hidden_states(image_inputs)

            # See NOTE in _parse_and_validate_inputs
            image_bounds = image_inputs["image_bounds"]
            if len(image_bounds) > 0:
                image_indices = torch.stack([
                    torch.arange(start, end, dtype=torch.long)
                    for start, end in image_bounds.tolist()
                ]).to(vlm_embedding.device)

                vlm_embedding.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(1,
                                                     vlm_embedding.shape[-1]),
                    vision_hidden_states.view(-1,
                                              vision_hidden_states.shape[-1]),
                )

        return vlm_embedding, vision_hidden_states
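
    # NOTE: get_embedding above relies on the input processor having reserved
    # one placeholder token per vision embedding; scatter_ then overwrites the
    # text-embedding rows between each (start, stop) image bound with the
    # corresponding rows of `vision_hidden_states`.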

    def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
        tokenizer = cached_get_tokenizer(self.config._name_or_path,
                                         trust_remote_code=True)
        start_cond = input_ids == tokenizer.im_start_id
        end_cond = input_ids == tokenizer.im_end_id
        if hasattr(tokenizer, "slice_start_id"):
            start_cond |= (input_ids == tokenizer.slice_start_id)
            end_cond |= (input_ids == tokenizer.slice_end_id)

        image_start_tokens, = torch.where(start_cond)
        image_start_tokens += 1
        image_end_tokens, = torch.where(end_cond)
        valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))

        if valid_image_nums == 0:
            return torch.zeros((0, 2), device=input_ids.device)

        return torch.hstack([
            image_start_tokens[:valid_image_nums].unsqueeze(-1),
            image_end_tokens[:valid_image_nums].unsqueeze(-1),
        ])

    def _parse_and_validate_inputs(
        self,
        input_ids: torch.Tensor,
        **kwargs: object,
    ) -> Optional[MiniCPMVImageInputs]:
        pixel_values = kwargs.pop("pixel_values", [])
        tgt_sizes = kwargs.pop("tgt_sizes", [])

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError("Incorrect type of target sizes. "
                             f"Got type: {type(tgt_sizes)}")

        if len(pixel_values) != len(tgt_sizes):
            raise ValueError("Inconsistent batch lengths, found: "
                             f"{len(pixel_values)} vs. {len(tgt_sizes)}")

        pixel_values_flat: List[torch.Tensor] = []
        tgt_sizes_flat: List[torch.Tensor] = []
        for b in range(len(pixel_values)):
            pixel_values_flat += pixel_values[b]
            tgt_sizes_flat += tgt_sizes[b]

        # NOTE: Input IDs do not contain image tokens during memory profiling,
        # so we allow them to be empty
        if len(pixel_values_flat) != len(tgt_sizes_flat):
            raise ValueError("Inconsistent flattened lengths, found: "
                             f"{len(pixel_values_flat)} vs. "
                             f"{len(tgt_sizes_flat)}")

        if len(pixel_values_flat) == 0:
            return None

        return MiniCPMVImageInputs(
            image_bounds=self._get_image_bounds(input_ids),
            pixel_values=pixel_values_flat,
            tgt_sizes=torch.stack(tgt_sizes_flat),
        )
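
    # NOTE: pixel_values/tgt_sizes arrive as one list per request in the
    # batch; _parse_and_validate_inputs flattens them into a single per-image
    # list that the vision code below consumes.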

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)

        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)

        output = self.llm(
            input_ids=None,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=vlm_embeddings,
        )
        return output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            use_default_weight_loading = False
            if self.is_default_weight_loading(name):
                use_default_weight_loading = True
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        raise NotImplementedError

    def init_vision_module(self) -> nn.Module:
        raise NotImplementedError

    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
        raise NotImplementedError

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        raise NotImplementedError

    def is_default_weight_loading(self, name: str) -> bool:
        raise NotImplementedError


class MiniCPMV2(MiniCPMVBaseModel):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config, multimodal_config, cache_config, quant_config)
        assert self.version == (2, 0)

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        return MiniCPMModel(config,
                            cache_config=cache_config,
                            quant_config=quant_config)

    def init_vision_module(self) -> nn.Module:
        # TODO: refactor this vision model
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm==0.9.10") from e

        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        if (isinstance(model, timm.models.VisionTransformer)
                and model.attn_pool is not None):
            model.attn_pool = torch.nn.Identity()

        if self.config.drop_vision_last_layer:
            model.blocks = model.blocks[:-1]

        return model

    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2(
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                grid_size=int(math.sqrt(self.config.query_num)),
                kv_dim=vision_dim,
                adaptive=True,
            )
        return resampler

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        res = []
        dtype = self.vpm.pos_embed.data.dtype
        for pixel_value in pixel_values:
            H, W = pixel_value[0].shape[-2:]
            tgt_size = (
                math.ceil(H / self.vpm.patch_embed.patch_size[0]),
                math.ceil(W / self.vpm.patch_embed.patch_size[0]),
            )
            vision_embedding = self.vpm.forward_features(
                pixel_value.unsqueeze(0).type(dtype))
            if (hasattr(self.vpm, "num_prefix_tokens")
                    and self.vpm.num_prefix_tokens > 0):
                vision_embedding = vision_embedding[:, self.vpm.
                                                    num_prefix_tokens:]
            res.append(self.resampler(vision_embedding, tgt_size))
        return torch.vstack(res)

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]

        return self.get_vision_embedding(pixel_values)

    def is_default_weight_loading(self, name: str) -> bool:
        return "resampler" in name or "vpm" in name


class MiniCPMV2_5(MiniCPMVBaseModel):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config, multimodal_config, cache_config, quant_config)
        assert self.version == (2, 5)

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        return LlamaModel(config,
                          cache_config=cache_config,
                          quant_config=quant_config)

    def init_vision_module(self) -> nn.Module:
        model = Idefics2VisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
        return model

    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
            )
        return resampler

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        vision_embedding = self.vpm(pixel_values,
                                    patch_attention_mask=patch_attn_mask)
        vision_embedding = self.resampler(vision_embedding, tgt_sizes)
        return vision_embedding

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)
        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0)

        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2,
                                                    1).reshape(B, 3, -1, L)

        patch_attn_mask = torch.zeros((B, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for i in range(B):
            # Mark only the valid (non-padded) patches of each image; padded
            # positions stay False so the vision encoder ignores them.
            patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        return self.get_vision_embedding(all_pixel_values.type(dtype),
                                         patch_attn_mask, tgt_sizes)

    def is_default_weight_loading(self, name: str) -> bool:
        return "resampler" in name


# NOTE: Currently, information about this model is unavailable. We are
# temporarily using `MiniCPMVQwen2` as its name. The name may need
# to be modified in the future.
class MiniCPMVQwen2(MiniCPMVBaseModel):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config, multimodal_config, cache_config, quant_config)

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        return Qwen2Model(config,
                          cache_config=cache_config,
                          quant_config=quant_config)

    def init_vision_module(self) -> nn.Module:
        # A custom version of SiglipVisionTransformer; it won't work with TP.
        from aphrodite.modeling.models.na_vit import SiglipVisionTransformer

        if self.config._attn_implementation == "flash_attention_2":
            self.config.vision_config._attn_implementation = "flash_attention_2"
        else:
            # sdpa is not supported
            self.config.vision_config._attn_implementation = "eager"
        model = SiglipVisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
        return model

    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
            )
        return resampler

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        vision_embedding = self.vpm(
            pixel_values,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        ).last_hidden_state
        return vision_embedding

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)
        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0)

        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2,
                                                    1).reshape(B, 3, -1, L)

        patch_attn_mask = torch.zeros((B, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for i in range(B):
            patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
        vision_embedding = self.vpm(
            all_pixel_values.type(dtype),
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        ).last_hidden_state

        return self.resampler(vision_embedding, tgt_sizes)

    def is_default_weight_loading(self, name: str) -> bool:
        return "resampler" in name or "vpm" in name


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(MiniCPMVBaseModel):
    """
    Different versions of MiniCPM-V use different visual encoders and LLMs,
    which is not conducive to the current integration logic of LoRA and
    bitsandbytes in Aphrodite. Therefore, it is necessary to separate them.
    """

    def __new__(
        cls,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        if not hasattr(config, "version"):
            if config.hidden_size == 2304 and config.query_num == 64:
                version = (2, 0)
            else:
                version = (2, 5)
        else:
            version = str(config.version).split(".")
            version = tuple([int(x) for x in version])
        # Dispatch class based on version
        if version == (2, 0):
            instance_class = MiniCPMV2
        elif version == (2, 5):
            instance_class = MiniCPMV2_5
        else:
            instance_class = MiniCPMVQwen2
        return instance_class(config, multimodal_config, cache_config,
                              quant_config)
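

# Usage sketch (commented out; assumes the Aphrodite offline-inference API,
# which mirrors vLLM's -- the entrypoints below are an assumption, not
# something this module defines; chat-template wrapping is omitted):
#
#   from PIL import Image
#   from aphrodite import LLM, SamplingParams
#
#   llm = LLM(model="openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True)
#   image = Image.open("example.jpg")
#   outputs = llm.generate(
#       {
#           "prompt": "(<image>./</image>)\nDescribe this image.",
#           "multi_modal_data": {"image": image},
#       },
#       SamplingParams(max_tokens=64),
#   )
#   print(outputs[0].outputs[0].text)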