minicpmv.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
  24. import math
  25. import re
  26. from functools import partial
  27. from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict,
  28. Union)
  29. import numpy as np
  30. import torch
  31. import torch.nn.functional as F
  32. import torch.types
  33. from PIL import Image
  34. from torch import nn
  35. from torch.nn.init import trunc_normal_
  36. from transformers.configuration_utils import PretrainedConfig
  37. from aphrodite.attention import AttentionMetadata
  38. from aphrodite.common.config import CacheConfig, MultiModalConfig
  39. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  40. SequenceData)
  41. from aphrodite.common.utils import progress_bar
  42. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  43. from aphrodite.modeling.layers.linear import ReplicatedLinear
  44. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  45. from aphrodite.modeling.layers.sampler import Sampler
  46. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  47. from aphrodite.modeling.model_loader.utils import set_default_torch_dtype
  48. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  49. from aphrodite.modeling.models.interfaces import SupportsMultiModal
  50. from aphrodite.modeling.models.llama import LlamaModel
  51. from aphrodite.modeling.models.minicpm import MiniCPMModel
  52. from aphrodite.modeling.models.qwen2 import Qwen2Model
  53. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  54. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  55. from aphrodite.multimodal.image import (cached_get_image_processor,
  56. cached_get_tokenizer)
  57. from aphrodite.quantization.base_config import QuantizationConfig
  58. from .idefics2_vision_model import Idefics2VisionTransformer
  59. _KEYS_TO_MODIFY_MAPPING = {
  60. "llm.lm_head": "lm_head",
  61. "llm.model": "llm",
  62. }
  63. class MiniCPMVImagePixelInputs(TypedDict):
  64. pixel_values: List[torch.Tensor]
  65. """
  66. Shape: `(batch_size * num_images, num_channels, height, width)`
  67. Note that the image size may vary, so we pass it as a list
  68. instead of a batched tensor.
  69. """
  70. image_bounds: torch.Tensor
  71. """
  72. Shape: `(batch_size * num_images, 2)`
  73. This should be in `(start, stop)` format.
  74. """
  75. tgt_sizes: torch.Tensor
  76. """
  77. Shape: `(batch_size * num_images, 2)`
  78. This should be in `(height, width)` format.
  79. """
  80. MiniCPMVImageInputs = MiniCPMVImagePixelInputs
  81. DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
  82. def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
  83. # abs_pos: L, C
  84. # tgt_size: (H, W)
  85. # return: M, C
  86. src_size = int(math.sqrt(abs_pos.size(0)))
  87. # tgt_size = int(math.sqrt(tgt_size))
  88. dtype = abs_pos.dtype
  89. return (F.interpolate(
  90. abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
  91. size=(tgt_size[0], tgt_size[1]),
  92. mode="bicubic",
  93. align_corners=False,
  94. ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype))
  95. # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
  96. def get_2d_sincos_pos_embed(
  97. embed_dim: int,
  98. grid_size: Union[int, Tuple[int, int]],
  99. cls_token: bool = False,
  100. version: Tuple[int, int] = (2, 0),
  101. ):
  102. """
  103. grid_size: int of the grid height and width
  104. return:
  105. pos_embed: [grid_size*grid_size, embed_dim] or
  106. [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
  107. """
  108. if isinstance(grid_size, int):
  109. grid_h_size, grid_w_size = grid_size, grid_size
  110. else:
  111. grid_h_size, grid_w_size = grid_size[0], grid_size[1]
  112. grid_h = np.arange(grid_h_size, dtype=np.float32)
  113. grid_w = np.arange(grid_w_size, dtype=np.float32)
  114. grid = np.meshgrid(grid_w, grid_h) # here w goes first
  115. grid = np.stack(grid, axis=0)
  116. if version == (2, 0):
  117. grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
  118. pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
  119. if cls_token:
  120. pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
  121. axis=0)
  122. else:
  123. pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
  124. return pos_embed
  125. def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
  126. grid: np.ndarray,
  127. version: Tuple[int, int] = (2, 0)):
  128. assert embed_dim % 2 == 0
  129. # use half of dimensions to encode grid_h
  130. emb_h = get_1d_sincos_pos_embed_from_grid(
  131. embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2)
  132. emb_w = get_1d_sincos_pos_embed_from_grid(
  133. embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2)
  134. if version == (2, 0):
  135. emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
  136. else:
  137. emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
  138. return emb
  139. def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
  140. pos: np.ndarray,
  141. version: Tuple[int, int] = (2, 0)):
  142. """
  143. embed_dim: output dimension for each position
  144. pos: a list of positions to be encoded: size (M,) / (H, W)
  145. out: (M, D) / (H, W, D)
  146. """
  147. assert embed_dim % 2 == 0
  148. omega = np.arange(embed_dim // 2, dtype=np.float32)
  149. omega /= embed_dim / 2.0
  150. omega = 1.0 / 10000**omega # (D/2,)
  151. if version == (2, 0):
  152. pos = pos.reshape(-1) # (M,)
  153. out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
  154. emb_sin = np.sin(out) # (M, D/2)
  155. emb_cos = np.cos(out) # (M, D/2)
  156. emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
  157. else:
  158. out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
  159. emb_sin = np.sin(out) # (H, W, D/2)
  160. emb_cos = np.cos(out) # (H, W, D/2)
  161. emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
  162. return emb
  163. class BaseResampler(nn.Module):
  164. """
  165. A 2D perceiver-resampler network with one cross attention layers by
  166. (grid_size**2) learnable queries and 2d sincos pos_emb
  167. Outputs:
  168. A tensor with the shape of (grid_size**2, embed_dim)
  169. """
  170. def __init__(
  171. self,
  172. num_queries: int,
  173. embed_dim: int,
  174. num_heads: int,
  175. kv_dim: Optional[int] = None,
  176. norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
  177. ) -> None:
  178. super().__init__()
  179. self.num_queries = num_queries
  180. self.embed_dim = embed_dim
  181. self.num_heads = num_heads
  182. self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
  183. trunc_normal_(self.query, std=0.02)
  184. if kv_dim is not None and kv_dim != embed_dim:
  185. self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
  186. else:
  187. # Maintain the same return value with ReplicatedLinear.forward
  188. self.kv_proj = lambda *args, **kwargs: (
  189. nn.Identity()(*args, **kwargs),
  190. None,
  191. )
  192. self.attn = nn.MultiheadAttention(embed_dim, num_heads)
  193. self.ln_q = norm_layer(embed_dim)
  194. self.ln_kv = norm_layer(embed_dim)
  195. self.ln_post = norm_layer(embed_dim)
  196. self.proj = nn.Parameter(
  197. (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))
  198. def _init_weights(self, m: nn.Module) -> None:
  199. if isinstance(m, nn.Linear):
  200. trunc_normal_(m.weight, std=0.02)
  201. if isinstance(m, nn.Linear) and m.bias is not None:
  202. nn.init.constant_(m.bias, 0)
  203. elif isinstance(m, nn.LayerNorm):
  204. nn.init.constant_(m.bias, 0)
  205. nn.init.constant_(m.weight, 1.0)
  206. def _repeat(self, query, N: int):
  207. return query.unsqueeze(1).repeat(1, N, 1)
  208. class Resampler2(BaseResampler):
  209. def __init__(
  210. self,
  211. grid_size: int,
  212. embed_dim: int,
  213. num_heads: int,
  214. kv_dim: Optional[int] = None,
  215. norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
  216. adaptive: bool = False,
  217. ) -> None:
  218. super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
  219. norm_layer)
  220. self.adaptive = adaptive
  221. pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
  222. grid_size,
  223. version=(2, 0))
  224. self.pos_embed = nn.Parameter(
  225. torch.from_numpy(pos_embed_arr).float()).requires_grad_(False)
  226. self.apply(self._init_weights)
  227. def forward(
  228. self,
  229. x: torch.Tensor,
  230. tgt_sizes: torch.Tensor,
  231. attn_mask: Optional[torch.Tensor] = None,
  232. ):
  233. if self.adaptive:
  234. pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
  235. tgt_sizes,
  236. version=(2, 0))
  237. pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
  238. dtype=x.dtype)
  239. else:
  240. pos_embed = get_abs_pos(self.pos_embed, tgt_sizes)
  241. x, _ = self.kv_proj(x)
  242. x = self.ln_kv(x).permute(1, 0, 2)
  243. N = x.shape[1]
  244. q = self.ln_q(self.query)
  245. out = self.attn(
  246. self._repeat(q, N) + self.pos_embed.unsqueeze(1),
  247. x + pos_embed.unsqueeze(1),
  248. x,
  249. attn_mask=attn_mask,
  250. )[0]
  251. x = out.permute(1, 0, 2)
  252. x = self.ln_post(x)
  253. x = x @ self.proj
  254. return x
  255. class Resampler2_5(BaseResampler):
  256. def __init__(
  257. self,
  258. num_queries: int,
  259. embed_dim: int,
  260. num_heads: int,
  261. kv_dim: Optional[int] = None,
  262. norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
  263. max_size: Tuple[int, int] = (70, 70),
  264. ) -> None:
  265. super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
  266. self.max_size = max_size
  267. self._set_2d_pos_cache(self.max_size)
  268. self.apply(self._init_weights)
  269. def _set_2d_pos_cache(self,
  270. max_size: Tuple[int, int],
  271. device: torch.types.Device = "cpu") -> None:
  272. pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
  273. max_size,
  274. version=(2, 5))
  275. pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
  276. self.register_buffer("pos_embed", pos_embed, persistent=False)
  277. def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
  278. device: torch.types.Device) -> None:
  279. max_h = tgt_sizes[:, 0].max().item()
  280. max_w = tgt_sizes[:, 1].max().item()
  281. assert isinstance(max_h, int) and isinstance(max_w, int)
  282. if max_h > self.max_size[0] or max_w > self.max_size[1]:
  283. self.max_size = (
  284. max(max_h, self.max_size[0]),
  285. max(max_w, self.max_size[1]),
  286. )
  287. self._set_2d_pos_cache(self.max_size, device)
  288. def forward(self, x: torch.Tensor,
  289. tgt_sizes: torch.Tensor) -> torch.Tensor:
  290. assert x.shape[0] == tgt_sizes.shape[0]
  291. bs = x.shape[0]
  292. device = x.device
  293. dtype = x.dtype
  294. patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
  295. self._adjust_pos_cache(tgt_sizes, device=device)
  296. max_patch_len = patch_len.max().item()
  297. assert isinstance(max_patch_len, int)
  298. key_padding_mask = torch.zeros((bs, max_patch_len),
  299. dtype=torch.bool,
  300. device=device)
  301. pos_embed = []
  302. for i in range(bs):
  303. tgt_h, tgt_w = tgt_sizes[i].tolist()
  304. pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
  305. (tgt_h * tgt_w, -1)).to(dtype)) # patches * D
  306. key_padding_mask[i, patch_len[i]:] = True
  307. pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed,
  308. batch_first=True,
  309. padding_value=0.0).permute(
  310. 1, 0,
  311. 2) # BLD => L * B * D
  312. x, _ = self.kv_proj(x) # B * L * D
  313. x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
  314. q = self.ln_q(self.query) # Q * D
  315. out = self.attn(
  316. self._repeat(q, bs), # Q * B * D
  317. x + pos_embed, # L * B * D + L * B * D
  318. x,
  319. key_padding_mask=key_padding_mask,
  320. )[0]
  321. # out: Q * B * D
  322. x = out.permute(1, 0, 2) # B * Q * D
  323. x = self.ln_post(x)
  324. x = x @ self.proj
  325. return x
  326. def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
  327. version_float = getattr(config, "version", None)
  328. # The old configs do not include version number
  329. # TODO: Remove this after the HF repos are updated
  330. if version_float is None:
  331. if config.hidden_size == 2304 and config.query_num == 64:
  332. return (2, 0)
  333. return (2, 5)
  334. version_str = str(version_float)
  335. return tuple(int(x) for x in version_str.split("."))
  336. def get_max_minicpmv_image_tokens(ctx: InputContext):
  337. hf_config = ctx.get_hf_config(PretrainedConfig)
  338. return getattr(hf_config, "query_num", 64)
  339. def dummy_seq_data_for_minicpmv(seq_len: int):
  340. token_ids = [0] * seq_len
  341. return SequenceData(token_ids)
  342. def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
  343. width = height = hf_config.image_size
  344. image = Image.new("RGB", (width, height), color=0)
  345. return {"image": image}
  346. def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
  347. hf_config = ctx.get_hf_config(PretrainedConfig)
  348. seq_data = dummy_seq_data_for_minicpmv(seq_len)
  349. mm_data = dummy_image_for_minicpmv(hf_config)
  350. return seq_data, mm_data
  351. def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
  352. multi_modal_data = llm_inputs.get("multi_modal_data")
  353. if multi_modal_data is None or "image" not in multi_modal_data:
  354. return llm_inputs
  355. model_config = ctx.model_config
  356. version = get_version_by_config(model_config.hf_config)
  357. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  358. trust_remote_code=True)
  359. image_processor = cached_get_image_processor(model_config.tokenizer)
  360. def get_placeholder(image_size: Tuple[int, int], num_image: int):
  361. if version == (2, 0) or version == (2, 5):
  362. return image_processor. \
  363. get_slice_image_placeholder(image_size)
  364. return image_processor. \
  365. get_slice_image_placeholder(image_size, num_image)
  366. prompt = llm_inputs.get("prompt")
  367. if prompt is None:
  368. token_ids = llm_inputs.get("prompt_token_ids")
  369. prompt = tokenizer.decode(token_ids)
  370. pattern = "(<image>./</image>)"
  371. images = multi_modal_data["image"]
  372. if isinstance(images, Image.Image):
  373. images = [images]
  374. image_tags = re.findall(pattern, prompt)
  375. if len(image_tags) == 0:
  376. new_token_ids = token_ids
  377. new_prompt = prompt
  378. else:
  379. text_chunks = prompt.split(pattern)
  380. new_prompt_chunks: List[str] = []
  381. for i in range(len(images)):
  382. new_prompt_chunks += [
  383. text_chunks[i],
  384. get_placeholder(images[i].size, i)
  385. ]
  386. new_prompt_chunks.append(text_chunks[-1])
  387. new_prompt = "".join(new_prompt_chunks)
  388. new_token_ids = tokenizer.encode(new_prompt)
  389. llm_inputs = LLMInputs(
  390. prompt_token_ids=new_token_ids,
  391. prompt=new_prompt,
  392. multi_modal_data=multi_modal_data,
  393. )
  394. return llm_inputs
  395. class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
  396. """
  397. The abstract class of MiniCPMV can only be inherited, but cannot be
  398. instantiated.
  399. """
  400. def __init__(
  401. self,
  402. config: PretrainedConfig,
  403. multimodal_config: MultiModalConfig,
  404. cache_config: Optional[CacheConfig] = None,
  405. quant_config: Optional[QuantizationConfig] = None,
  406. ):
  407. super().__init__()
  408. self.config = config
  409. self.multimodal_config = multimodal_config
  410. self.version = get_version_by_config(self.config)
  411. self.llm = self.init_llm(config, cache_config, quant_config)
  412. self.vpm = self.init_vision_module()
  413. param_dtype = torch.get_default_dtype()
  414. self.vpm.to(dtype=param_dtype)
  415. self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
  416. self.vpm.embeddings.embed_dim)
  417. self.embed_dim = self.config.hidden_size
  418. self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
  419. self.resampler.to(device="cuda", dtype=param_dtype)
  420. self.lm_head = ParallelLMHead(config.vocab_size,
  421. config.hidden_size,
  422. quant_config=quant_config)
  423. self.logits_processor = LogitsProcessor(config.vocab_size)
  424. self.sampler = Sampler()
  425. def get_embedding(
  426. self,
  427. input_ids: torch.Tensor,
  428. image_inputs: Optional[MiniCPMVImageInputs],
  429. ) -> Tuple[torch.Tensor, torch.Tensor]:
  430. vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
  431. if hasattr(self.config, "scale_emb"):
  432. vlm_embedding *= self.config.scale_emb
  433. if image_inputs is None: # No image
  434. vision_hidden_states = torch.tensor([], device=input_ids.device)
  435. else:
  436. vision_hidden_states = self.get_vision_hidden_states(image_inputs)
  437. # See NOTE in _parse_and_validate_inputs
  438. image_bounds = image_inputs["image_bounds"]
  439. if len(image_bounds) > 0:
  440. image_indices = torch.stack([
  441. torch.arange(start, end, dtype=torch.long)
  442. for start, end in image_bounds.tolist()
  443. ]).to(vlm_embedding.device)
  444. vlm_embedding.scatter_(
  445. 0,
  446. image_indices.view(-1, 1).repeat(1,
  447. vlm_embedding.shape[-1]),
  448. vision_hidden_states.view(-1,
  449. vision_hidden_states.shape[-1]),
  450. )
  451. return vlm_embedding, vision_hidden_states
  452. def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
  453. tokenizer = cached_get_tokenizer(self.config._name_or_path,
  454. trust_remote_code=True)
  455. start_cond = input_ids == tokenizer.im_start_id
  456. end_cond = input_ids == tokenizer.im_end_id
  457. if hasattr(tokenizer, "slice_start_id"):
  458. start_cond |= (input_ids == tokenizer.slice_start_id)
  459. end_cond |= (input_ids == tokenizer.slice_end_id)
  460. image_start_tokens, = torch.where(start_cond)
  461. image_start_tokens += 1
  462. image_end_tokens, = torch.where(end_cond)
  463. valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
  464. if valid_image_nums == 0:
  465. return torch.zeros((0, 2), device=input_ids.device)
  466. return torch.hstack([
  467. image_start_tokens[:valid_image_nums].unsqueeze(-1),
  468. image_end_tokens[:valid_image_nums].unsqueeze(-1),
  469. ])
  470. def _parse_and_validate_inputs(
  471. self,
  472. input_ids: torch.Tensor,
  473. **kwargs: object,
  474. ) -> Optional[MiniCPMVImageInputs]:
  475. pixel_values = kwargs.pop("pixel_values", [])
  476. tgt_sizes = kwargs.pop("tgt_sizes", [])
  477. if not isinstance(pixel_values, (torch.Tensor, list)):
  478. raise ValueError("Incorrect type of pixel values. "
  479. f"Got type: {type(pixel_values)}")
  480. if not isinstance(tgt_sizes, (torch.Tensor, list)):
  481. raise ValueError("Incorrect type of target sizes. "
  482. f"Got type: {type(tgt_sizes)}")
  483. if len(pixel_values) != len(tgt_sizes):
  484. raise ValueError("Inconsistent batch lengths, found: "
  485. f"{len(pixel_values)} vs. {len(tgt_sizes)}")
  486. pixel_values_flat: List[torch.Tensor] = []
  487. tgt_sizes_flat: List[torch.Tensor] = []
  488. for b in range(len(pixel_values)):
  489. pixel_values_flat += pixel_values[b]
  490. tgt_sizes_flat += tgt_sizes[b]
  491. # NOTE: Input IDs does not contain image tokens during memory profiling,
  492. # so we allow it to be empty
  493. if len(pixel_values_flat) != len(tgt_sizes_flat):
  494. raise ValueError("Inconsistent flattened lengths, found: "
  495. f"{len(pixel_values_flat)} vs. "
  496. f"{len(tgt_sizes_flat)}")
  497. if len(pixel_values_flat) == 0:
  498. return None
  499. return MiniCPMVImageInputs(
  500. image_bounds=self._get_image_bounds(input_ids),
  501. pixel_values=pixel_values_flat,
  502. tgt_sizes=torch.stack(tgt_sizes_flat),
  503. )
  504. def forward(
  505. self,
  506. input_ids: torch.Tensor,
  507. positions: torch.Tensor,
  508. kv_caches: List[torch.Tensor],
  509. attn_metadata: AttentionMetadata,
  510. intermediate_tensors: Optional[IntermediateTensors] = None,
  511. **kwargs: Any,
  512. ) -> torch.Tensor:
  513. image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
  514. vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
  515. output = self.llm(
  516. input_ids=None,
  517. positions=positions,
  518. kv_caches=kv_caches,
  519. attn_metadata=attn_metadata,
  520. intermediate_tensors=intermediate_tensors,
  521. inputs_embeds=vlm_embeddings,
  522. )
  523. return output
  524. def compute_logits(
  525. self,
  526. hidden_states: torch.Tensor,
  527. sampling_metadata: SamplingMetadata,
  528. ) -> Optional[torch.Tensor]:
  529. logits = self.logits_processor(self.lm_head, hidden_states,
  530. sampling_metadata)
  531. return logits
  532. def sample(
  533. self,
  534. logits: torch.Tensor,
  535. sampling_metadata: SamplingMetadata,
  536. ) -> Optional[SamplerOutput]:
  537. next_tokens = self.sampler(logits, sampling_metadata)
  538. return next_tokens
  539. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  540. stacked_params_mapping = [
  541. # (param_name, shard_name, shard_id)
  542. ("qkv_proj", "q_proj", "q"),
  543. ("qkv_proj", "k_proj", "k"),
  544. ("qkv_proj", "v_proj", "v"),
  545. ("gate_up_proj", "gate_proj", 0),
  546. ("gate_up_proj", "up_proj", 1),
  547. ]
  548. params_dict = dict(self.named_parameters())
  549. weights_list = list(weights)
  550. for name, loaded_weight in progress_bar(weights_list,
  551. desc="Loading modules..."):
  552. for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
  553. if key_to_modify in name:
  554. name = name.replace(key_to_modify, new_key)
  555. if "rotary_emb.inv_freq" in name:
  556. continue
  557. if ("rotary_emb.cos_cached" in name
  558. or "rotary_emb.sin_cached" in name):
  559. # Models trained using ColossalAI may include these tensors in
  560. # the checkpoint. Skip them.
  561. continue
  562. use_default_weight_loading = False
  563. if self.is_default_weight_loading(name):
  564. use_default_weight_loading = True
  565. else:
  566. for param_name, weight_name, shard_id in stacked_params_mapping:
  567. if weight_name not in name:
  568. continue
  569. param = params_dict[name.replace(weight_name, param_name)]
  570. weight_loader = param.weight_loader
  571. weight_loader(param, loaded_weight, shard_id)
  572. break
  573. else:
  574. use_default_weight_loading = True
  575. if use_default_weight_loading:
  576. param = params_dict[name]
  577. weight_loader = getattr(param, "weight_loader",
  578. default_weight_loader)
  579. weight_loader(param, loaded_weight)
  580. def init_llm(
  581. self,
  582. config: PretrainedConfig,
  583. cache_config: Optional[CacheConfig] = None,
  584. quant_config: Optional[QuantizationConfig] = None,
  585. ) -> nn.Module:
  586. raise NotImplementedError
  587. def init_vision_module(self) -> nn.Module:
  588. raise NotImplementedError
  589. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  590. raise NotImplementedError
  591. def get_vision_embedding(
  592. self,
  593. pixel_values: List[torch.Tensor],
  594. patch_attn_mask: Optional[torch.Tensor] = None,
  595. tgt_sizes: Optional[torch.Tensor] = None,
  596. ) -> torch.Tensor:
  597. raise NotImplementedError
  598. def get_vision_hidden_states(self,
  599. data: MiniCPMVImageInputs) -> torch.Tensor:
  600. raise NotImplementedError
  601. def is_default_weight_loading(self, name: str) -> bool:
  602. raise NotImplementedError
  603. class MiniCPMV2(MiniCPMVBaseModel):
  604. def __init__(
  605. self,
  606. config: PretrainedConfig,
  607. multimodal_config: MultiModalConfig,
  608. cache_config: Optional[CacheConfig] = None,
  609. quant_config: Optional[QuantizationConfig] = None,
  610. ):
  611. super().__init__(config, multimodal_config, cache_config, quant_config)
  612. assert self.version == (2, 0)
  613. def init_llm(
  614. self,
  615. config: PretrainedConfig,
  616. cache_config: Optional[CacheConfig] = None,
  617. quant_config: Optional[QuantizationConfig] = None,
  618. ) -> nn.Module:
  619. return MiniCPMModel(config,
  620. cache_config=cache_config,
  621. quant_config=quant_config)
  622. def init_vision_module(self) -> nn.Module:
  623. # TODO :refactor this vision model
  624. try:
  625. import timm
  626. except ImportError:
  627. raise ImportError("Please install timm==0.9.10") from ImportError
  628. with set_default_torch_dtype(torch.float16):
  629. model = timm.create_model(
  630. "vit_so400m_patch14_siglip_384.webli",
  631. pretrained=False,
  632. num_classes=0,
  633. dynamic_img_size=True,
  634. dynamic_img_pad=True,
  635. )
  636. if (isinstance(model, timm.models.VisionTransformer)
  637. and model.attn_pool is not None):
  638. model.attn_pool = torch.nn.Identity()
  639. if self.config.drop_vision_last_layer:
  640. model.blocks = model.blocks[:-1]
  641. return model
  642. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  643. with set_default_torch_dtype(torch.float16):
  644. resampler = Resampler2(
  645. embed_dim=embed_dim,
  646. num_heads=embed_dim // 128,
  647. grid_size=int(math.sqrt(self.config.query_num)),
  648. kv_dim=vision_dim,
  649. adaptive=True,
  650. )
  651. return resampler
  652. def get_vision_embedding(
  653. self,
  654. pixel_values: List[torch.Tensor],
  655. patch_attn_mask: Optional[torch.Tensor] = None,
  656. tgt_sizes: Optional[torch.Tensor] = None,
  657. ) -> torch.Tensor:
  658. res = []
  659. dtype = self.vpm.pos_embed.data.dtype
  660. for pixel_value in pixel_values:
  661. H, W = pixel_value[0].shape[-2:]
  662. tgt_size = (
  663. math.ceil(H / self.vpm.patch_embed.patch_size[0]),
  664. math.ceil(W / self.vpm.patch_embed.patch_size[0]),
  665. )
  666. vision_embedding = self.vpm.forward_features(
  667. pixel_value.unsqueeze(0).type(dtype))
  668. if (hasattr(self.vpm, "num_prefix_tokens")
  669. and self.vpm.num_prefix_tokens > 0):
  670. vision_embedding = vision_embedding[:, self.vpm.
  671. num_prefix_tokens:]
  672. res.append(self.resampler(vision_embedding, tgt_size))
  673. return torch.vstack(res)
  674. def get_vision_hidden_states(self,
  675. data: MiniCPMVImageInputs) -> torch.Tensor:
  676. pixel_values = data["pixel_values"]
  677. return self.get_vision_embedding(pixel_values)
  678. def is_default_weight_loading(self, name: str) -> bool:
  679. return "resampler" in name or "vpm" in name
  680. class MiniCPMV2_5(MiniCPMVBaseModel):
  681. def __init__(
  682. self,
  683. config: PretrainedConfig,
  684. multimodal_config: MultiModalConfig,
  685. cache_config: Optional[CacheConfig] = None,
  686. quant_config: Optional[QuantizationConfig] = None,
  687. ):
  688. super().__init__(config, multimodal_config, cache_config, quant_config)
  689. assert self.version == (2, 5)
  690. def init_llm(
  691. self,
  692. config: PretrainedConfig,
  693. cache_config: Optional[CacheConfig] = None,
  694. quant_config: Optional[QuantizationConfig] = None,
  695. ) -> nn.Module:
  696. return LlamaModel(config,
  697. cache_config=cache_config,
  698. quant_config=quant_config)
  699. def init_vision_module(self) -> nn.Module:
  700. model = Idefics2VisionTransformer(self.config.vision_config)
  701. if self.config.drop_vision_last_layer:
  702. model.encoder.layers = model.encoder.layers[:-1]
  703. return model
  704. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  705. with set_default_torch_dtype(torch.float16):
  706. resampler = Resampler2_5(
  707. num_queries=self.config.query_num,
  708. embed_dim=embed_dim,
  709. num_heads=embed_dim // 128,
  710. kv_dim=vision_dim,
  711. )
  712. return resampler
  713. def get_vision_embedding(
  714. self,
  715. pixel_values: List[torch.Tensor],
  716. patch_attn_mask: Optional[torch.Tensor] = None,
  717. tgt_sizes: Optional[torch.Tensor] = None,
  718. ) -> torch.Tensor:
  719. vision_embedding = self.vpm(pixel_values,
  720. patch_attention_mask=patch_attn_mask)
  721. vision_embedding = self.resampler(vision_embedding, tgt_sizes)
  722. return vision_embedding
  723. def get_vision_hidden_states(self,
  724. data: MiniCPMVImageInputs) -> torch.Tensor:
  725. pixel_values = data["pixel_values"]
  726. tgt_sizes = data["tgt_sizes"]
  727. device = self.vpm.embeddings.position_embedding.weight.device
  728. dtype = self.vpm.embeddings.position_embedding.weight.dtype
  729. all_pixel_values_lst = [
  730. i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
  731. ]
  732. max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
  733. assert isinstance(max_patches, int)
  734. all_pixel_values = torch.nn.utils.rnn.pad_sequence(
  735. all_pixel_values_lst, batch_first=True, padding_value=0.0)
  736. B, L, _ = all_pixel_values.shape
  737. all_pixel_values = all_pixel_values.permute(0, 2,
  738. 1).reshape(B, 3, -1, L)
  739. patch_attn_mask = torch.zeros((B, 1, max_patches),
  740. dtype=torch.bool,
  741. device=device)
  742. for i in range(B):
  743. patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
  744. return self.get_vision_embedding(all_pixel_values.type(dtype),
  745. patch_attn_mask, tgt_sizes)
  746. def is_default_weight_loading(self, name: str) -> bool:
  747. return "resampler" in name
# NOTE: Information about this model is currently unavailable, so we are
# temporarily using `MiniCPMVQwen2` as its name. The name may need to be
# changed in the future.
  751. class MiniCPMVQwen2(MiniCPMVBaseModel):
  752. def __init__(
  753. self,
  754. config: PretrainedConfig,
  755. multimodal_config: MultiModalConfig,
  756. cache_config: Optional[CacheConfig] = None,
  757. quant_config: Optional[QuantizationConfig] = None,
  758. ):
  759. super().__init__(config, multimodal_config, cache_config, quant_config)
  760. def init_llm(
  761. self,
  762. config: PretrainedConfig,
  763. cache_config: Optional[CacheConfig] = None,
  764. quant_config: Optional[QuantizationConfig] = None,
  765. ) -> nn.Module:
  766. return Qwen2Model(config,
  767. cache_config=cache_config,
  768. quant_config=quant_config)
  769. def init_vision_module(self) -> nn.Module:
  770. # A custom version of SiglipVisionTransformer, won't work with TP
  771. from aphrodite.modeling.models.na_vit import SiglipVisionTransformer
  772. if self.config._attn_implementation == "flash_attention_2":
  773. self.config.vision_config._attn_implementation = "flash_attention_2"
  774. else:
  775. # not support sdpa
  776. self.config.vision_config._attn_implementation = "eager"
  777. model = SiglipVisionTransformer(self.config.vision_config)
  778. if self.config.drop_vision_last_layer:
  779. model.encoder.layers = model.encoder.layers[:-1]
  780. return model
  781. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  782. with set_default_torch_dtype(torch.float16):
  783. resampler = Resampler2_5(
  784. num_queries=self.config.query_num,
  785. embed_dim=embed_dim,
  786. num_heads=embed_dim // 128,
  787. kv_dim=vision_dim,
  788. )
  789. return resampler
  790. def get_vision_embedding(
  791. self,
  792. pixel_values: List[torch.Tensor],
  793. patch_attn_mask: Optional[torch.Tensor] = None,
  794. tgt_sizes: Optional[torch.Tensor] = None,
  795. ) -> torch.Tensor:
  796. vision_embedding = self.vpm(
  797. pixel_values,
  798. patch_attention_mask=patch_attn_mask,
  799. tgt_sizes=tgt_sizes,
  800. ).last_hidden_state
  801. return vision_embedding
  802. def get_vision_hidden_states(self,
  803. data: MiniCPMVImageInputs) -> torch.Tensor:
  804. pixel_values = data["pixel_values"]
  805. tgt_sizes = data["tgt_sizes"]
  806. device = self.vpm.embeddings.position_embedding.weight.device
  807. dtype = self.vpm.embeddings.position_embedding.weight.dtype
  808. all_pixel_values_lst = [
  809. i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
  810. ]
  811. max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
  812. assert isinstance(max_patches, int)
  813. all_pixel_values = torch.nn.utils.rnn.pad_sequence(
  814. all_pixel_values_lst, batch_first=True, padding_value=0.0)
  815. B, L, _ = all_pixel_values.shape
  816. all_pixel_values = all_pixel_values.permute(0, 2,
  817. 1).reshape(B, 3, -1, L)
  818. patch_attn_mask = torch.zeros((B, 1, max_patches),
  819. dtype=torch.bool,
  820. device=device)
  821. for i in range(B):
  822. patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
  823. vision_embedding = self.vpm(
  824. all_pixel_values.type(dtype),
  825. patch_attention_mask=patch_attn_mask,
  826. tgt_sizes=tgt_sizes,
  827. ).last_hidden_state
  828. return self.resampler(vision_embedding, tgt_sizes)
  829. def is_default_weight_loading(self, name: str) -> bool:
  830. return "resampler" in name or "vpm" in name
  831. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  832. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
  833. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
  834. @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
  835. class MiniCPMV(MiniCPMVBaseModel):
  836. """
  837. Different versions of MiniCPMV use different visual encoders and LLMs,
  838. which is not conducive to the current integration logic of LoRA and
  839. bitsandbytes in aphrodite. Therefore, it is necessary to separate them.
  840. """
  841. def __new__(
  842. cls,
  843. config: PretrainedConfig,
  844. multimodal_config: MultiModalConfig,
  845. cache_config: Optional[CacheConfig] = None,
  846. quant_config: Optional[QuantizationConfig] = None,
  847. ):
  848. if not hasattr(config, "version"):
  849. if config.hidden_size == 2304 and config.query_num == 64:
  850. version = (2, 0)
  851. else:
  852. version = (2, 5)
  853. else:
  854. version = str(config.version).split(".")
  855. version = tuple([int(x) for x in version])
  856. # Dispatch class based on version
  857. if version == (2, 0):
  858. instance_class = MiniCPMV2
  859. elif version == (2, 5):
  860. instance_class = MiniCPMV2_5
  861. else:
  862. instance_class = MiniCPMVQwen2
  863. return instance_class(config, multimodal_config, cache_config,
  864. quant_config)