# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
                    TypedDict)

import torch
import torch.types
from PIL import Image
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SequenceData
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.linear import ReplicatedLinear
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.resampler import (Resampler2,
                                                 get_2d_sincos_pos_embed)
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.utils import set_default_torch_dtype
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsMultiModal
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.models.minicpm import MiniCPMModel
from aphrodite.modeling.models.qwen2 import Qwen2Model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import cached_get_image_processor
from aphrodite.multimodal.utils import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .idefics2_vision_model import Idefics2VisionTransformer

_KEYS_TO_MODIFY_MAPPING = {
    "llm.lm_head": "lm_head",
    "llm.model": "llm",
}


class MiniCPMVImagePixelInputs(TypedDict):
    pixel_values: List[torch.Tensor]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that the image size may vary, so we pass it as a list
    instead of a batched tensor.
    """

    image_bounds: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(start, stop)` format.
    """

    tgt_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


MiniCPMVImageInputs = MiniCPMVImagePixelInputs
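
# Illustrative sketch only (not part of the model code): a single image that the
# HF processor kept at 448x448 pixels and whose placeholder occupies token
# positions 5..68 would roughly correspond to an inputs dict like the one below.
# The concrete shapes and bounds depend on the image processor and tokenizer, so
# the numbers here are assumptions for exposition.
#
#   example_inputs = MiniCPMVImageInputs(
#       pixel_values=[torch.randn(3, 448, 448)],   # one variable-sized image
#       image_bounds=torch.tensor([[5, 69]]),      # (start, stop) token span
#       tgt_sizes=torch.tensor([[32, 32]]),        # (height, width) in patches
#   )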

DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


class BaseResampler(nn.Module):
    """
    A 2D perceiver-resampler network with a single cross-attention layer that
    maps the image features onto (grid_size**2) learnable queries, using 2D
    sincos positional embeddings.

    Outputs:
        A tensor of shape (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
    ) -> None:
        super().__init__()

        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=0.02)
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
        else:
            # Maintain the same return signature as ReplicatedLinear.forward
            self.kv_proj = lambda *args, **kwargs: (
                nn.Identity()(*args, **kwargs),
                None,
            )
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)
        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter(
            (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

    def _init_weights(self, m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)


class Resampler2_5(BaseResampler):

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        max_size: Tuple[int, int] = (70, 70),
    ) -> None:
        super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)

        self.max_size = max_size
        self._set_2d_pos_cache(self.max_size)

        self.apply(self._init_weights)

    def _set_2d_pos_cache(self,
                          max_size: Tuple[int, int],
                          device: torch.types.Device = "cpu") -> None:
        pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
                                                max_size,
                                                version=(2, 5))
        pos_embed = torch.from_numpy(pos_embed_arr).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
                          device: torch.types.Device) -> None:
        max_h = tgt_sizes[:, 0].max().item()
        max_w = tgt_sizes[:, 1].max().item()
        assert isinstance(max_h, int) and isinstance(max_w, int)

        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = (
                max(max_h, self.max_size[0]),
                max(max_w, self.max_size[1]),
            )
            self._set_2d_pos_cache(self.max_size, device)

    def forward(self, x: torch.Tensor,
                tgt_sizes: torch.Tensor) -> torch.Tensor:
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        max_patch_len = patch_len.max().item()
        assert isinstance(max_patch_len, int)

        key_padding_mask = torch.zeros((bs, max_patch_len),
                                       dtype=torch.bool,
                                       device=device)

        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i].tolist()
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape(
                (tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            key_padding_mask[i, patch_len[i]:] = True
        pos_embed = torch.nn.utils.rnn.pad_sequence(
            pos_embed, batch_first=True,
            padding_value=0.0).permute(1, 0, 2)  # B * L * D => L * B * D

        x, _ = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D + L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[0]
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x
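

# Illustrative shape walk-through for Resampler2_5.forward (the numbers are
# assumptions for exposition only): with two images whose tgt_sizes are
# (16, 16) and (12, 20), the padded key/value length is max(16 * 16, 12 * 20)
# = 256, so roughly:
#
#   x            : (2, 256, D) after kv_proj + ln_kv, permuted to (256, 2, D)
#   query        : (num_queries, D), repeated to (num_queries, 2, D)
#   attn output  : permuted back to (2, num_queries, D)
#
# i.e. every image is compressed to a fixed number of query tokens regardless
# of its original patch count; padded positions are masked via
# key_padding_mask.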


def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
    version_float = getattr(config, "version", None)

    # The old configs do not include version number
    # TODO: Remove this after the HF repos are updated
    if version_float is None:
        if config.hidden_size == 2304 and config.query_num == 64:
            return (2, 0)
        return (2, 5)

    version_str = str(version_float)
    return tuple(int(x) for x in version_str.split("."))
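
# Illustrative behaviour (the kwargs below are hypothetical and only set the
# attributes this helper reads; transformers' PretrainedConfig stores any
# extra kwargs as attributes):
#
#   cfg = PretrainedConfig(version=2.6, hidden_size=3584, query_num=64)
#   get_version_by_config(cfg)      # -> (2, 6)
#
#   legacy = PretrainedConfig(hidden_size=2304, query_num=64)
#   get_version_by_config(legacy)   # -> (2, 0), no version field present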


def get_max_minicpmv_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config()
    return getattr(hf_config, "query_num", 64)


def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
    return SequenceData.from_token_counts((0, seq_len))


def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
    width = height = hf_config.image_size
    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config()
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
    mm_data = dummy_image_for_minicpmv(hf_config, num_images)

    return seq_data, mm_data


def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs
    model_config = ctx.model_config
    version = get_version_by_config(model_config.hf_config)
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    image_processor = cached_get_image_processor(model_config.tokenizer)

    def get_placeholder(image_size: Tuple[int, int], num_image: int):
        if version == (2, 0) or version == (2, 5):
            return image_processor.get_slice_image_placeholder(image_size)
        return image_processor.get_slice_image_placeholder(
            image_size, num_image)

    # Fetch the token IDs up front so that the no-image-tag branch below can
    # fall back to them even when a text prompt was provided directly.
    prompt = llm_inputs.get("prompt")
    token_ids = llm_inputs.get("prompt_token_ids")
    if prompt is None:
        prompt = tokenizer.decode(token_ids)

    pattern = "(<image>./</image>)"
    images = multi_modal_data["image"]
    if isinstance(images, Image.Image):
        images = [images]
    image_tags = re.findall(pattern, prompt)

    if len(image_tags) == 0:
        new_token_ids = token_ids
        new_prompt = prompt
    else:
        text_chunks = prompt.split(pattern)
        new_prompt_chunks: List[str] = []
        for i in range(len(images)):
            new_prompt_chunks += [
                text_chunks[i],
                get_placeholder(images[i].size, i)
            ]
        new_prompt_chunks.append(text_chunks[-1])
        new_prompt = "".join(new_prompt_chunks)
        new_token_ids = tokenizer.encode(new_prompt)

    llm_inputs = LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
    return llm_inputs
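
# Illustrative flow (the exact placeholder string comes from the HF image
# processor and differs per model version, so the expansion shown here is an
# assumption): a prompt such as
#
#   "Describe this picture: (<image>./</image>)"
#
# has its "(<image>./</image>)" tag replaced by the slice-image placeholder,
# e.g. "<image><unk>...<unk></image>", and the result is re-tokenized so that
# the placeholder span can later be located by _get_image_bounds().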


class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
    """
    Abstract base class for the MiniCPM-V family; it is meant to be
    subclassed per model version, not instantiated directly.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # All MiniCPM-V models disable `tie_word_embeddings`, but
        # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
        # check `tie_word_embeddings` until vLLM integrates the MiniCPM-V
        # model and config class.
        self.config = config
        self.multimodal_config = multimodal_config

        self.version = get_version_by_config(self.config)
        self.llm = self.init_llm(config, cache_config, quant_config)
        self.vpm = self.init_vision_module()
        param_dtype = torch.get_default_dtype()
        self.vpm.to(dtype=param_dtype)
        self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                           self.vpm.embeddings.embed_dim)
        self.embed_dim = self.config.hidden_size
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
        self.resampler.to(device="cuda", dtype=param_dtype)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def get_embedding(
        self,
        input_ids: torch.Tensor,
        image_inputs: Optional[MiniCPMVImageInputs],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
        if hasattr(self.config, "scale_emb"):
            vlm_embedding *= self.config.scale_emb

        if image_inputs is None:  # No image
            vision_hidden_states = torch.tensor([], device=input_ids.device)
        else:
            vision_hidden_states = self.get_vision_hidden_states(image_inputs)

            # See NOTE in _parse_and_validate_inputs
            image_bounds = image_inputs["image_bounds"]
            if len(image_bounds) > 0:
                image_indices = torch.stack([
                    torch.arange(start, end, dtype=torch.long)
                    for start, end in image_bounds.tolist()
                ]).to(vlm_embedding.device)

                vlm_embedding.scatter_(
                    0,
                    image_indices.view(-1, 1).repeat(
                        1, vlm_embedding.shape[-1]),
                    vision_hidden_states.view(
                        -1, vision_hidden_states.shape[-1]),
                )

        return vlm_embedding, vision_hidden_states
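
    # Sketch of what get_embedding does with image placeholders (shapes are
    # assumptions for illustration): if image_bounds == [[5, 69]], token
    # positions 5..68 of the text embedding matrix are overwritten in place
    # with the resampled vision tokens:
    #
    #   vlm_embedding        : (seq_len, hidden_size)  from llm.embed_tokens
    #   vision_hidden_states : (1, 64, hidden_size)    from the resampler
    #   scatter_ along dim 0 replaces rows 5..68 with the flattened vision
    #   tokens.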

    def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
        tokenizer = cached_get_tokenizer(self.config._name_or_path,
                                         trust_remote_code=True)
        start_cond = input_ids == tokenizer.im_start_id
        end_cond = input_ids == tokenizer.im_end_id
        if hasattr(tokenizer, "slice_start_id"):
            start_cond |= (input_ids == tokenizer.slice_start_id)
            end_cond |= (input_ids == tokenizer.slice_end_id)

        image_start_tokens, = torch.where(start_cond)
        image_start_tokens += 1
        image_end_tokens, = torch.where(end_cond)
        valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))

        if valid_image_nums == 0:
            return torch.zeros((0, 2), device=input_ids.device)

        return torch.hstack([
            image_start_tokens[:valid_image_nums].unsqueeze(-1),
            image_end_tokens[:valid_image_nums].unsqueeze(-1),
        ])
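
    # Illustrative example for _get_image_bounds (token IDs are hypothetical):
    # with im_start_id == 101 and im_end_id == 102, an input of
    #
    #   input_ids = [..., 101, u, u, u, 102, ...]
    #                     pos 7          pos 11
    #
    # yields bounds [[8, 11]], i.e. the half-open span of placeholder tokens
    # strictly between the start and end markers.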

    def _parse_and_validate_inputs(
        self,
        input_ids: torch.Tensor,
        **kwargs: object,
    ) -> Optional[MiniCPMVImageInputs]:
        pixel_values = kwargs.pop("pixel_values", [])
        tgt_sizes = kwargs.pop("tgt_sizes", [])

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(tgt_sizes, (torch.Tensor, list)):
            raise ValueError("Incorrect type of target sizes. "
                             f"Got type: {type(tgt_sizes)}")

        if len(pixel_values) != len(tgt_sizes):
            raise ValueError("Inconsistent batch lengths, found: "
                             f"{len(pixel_values)} vs. {len(tgt_sizes)}")

        pixel_values_flat: List[torch.Tensor] = []
        tgt_sizes_flat: List[torch.Tensor] = []
        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
            if len(pixel_b) != len(tgt_b):
                raise ValueError("Inconsistent N lengths, found: "
                                 f"{len(pixel_b)} vs {len(tgt_b)}")

            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
                pixel_values_flat += pixel_n
                tgt_sizes_flat += tgt_n

        # NOTE: Input IDs do not contain image tokens during memory profiling,
        # so we allow them to be empty
        if len(pixel_values_flat) != len(tgt_sizes_flat):
            raise ValueError("Inconsistent flattened lengths, found: "
                             f"{len(pixel_values_flat)} vs. "
                             f"{len(tgt_sizes_flat)}")

        if len(pixel_values_flat) == 0:
            return None

        return MiniCPMVImageInputs(
            image_bounds=self._get_image_bounds(input_ids),
            pixel_values=pixel_values_flat,
            tgt_sizes=torch.stack(tgt_sizes_flat),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)

        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)

        output = self.llm(
            input_ids=None,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=vlm_embeddings,
        )
        return output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            use_default_weight_loading = False
            if self.is_default_weight_loading(name):
                use_default_weight_loading = True
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        raise NotImplementedError

    def init_vision_module(self) -> nn.Module:
        raise NotImplementedError

    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
        raise NotImplementedError

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        raise NotImplementedError

    def is_default_weight_loading(self, name: str) -> bool:
        raise NotImplementedError


class MiniCPMV2_0(MiniCPMVBaseModel):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config, multimodal_config, cache_config, quant_config)
        assert self.version == (2, 0)

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        return MiniCPMModel(config,
                            cache_config=cache_config,
                            quant_config=quant_config)

    def init_vision_module(self) -> nn.Module:
        # TODO: refactor this vision model
        try:
            import timm
        except ImportError as e:
            raise ImportError("Please install timm==0.9.10") from e
        with set_default_torch_dtype(torch.float16):
            model = timm.create_model(
                "vit_so400m_patch14_siglip_384.webli",
                pretrained=False,
                num_classes=0,
                dynamic_img_size=True,
                dynamic_img_pad=True,
            )

        if (isinstance(model, timm.models.VisionTransformer)
                and model.attn_pool is not None):
            model.attn_pool = torch.nn.Identity()

        if self.config.drop_vision_last_layer:
            model.blocks = model.blocks[:-1]

        return model

    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2(
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                grid_size=int(math.sqrt(self.config.query_num)),
                kv_dim=vision_dim,
                adaptive=False,
                do_post_projection=True,
            )

        return resampler

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        res = []
        dtype = self.vpm.pos_embed.data.dtype
        for pixel_value in pixel_values:
            H, W = pixel_value[0].shape[-2:]
            tgt_size = (
                math.ceil(H / self.vpm.patch_embed.patch_size[0]),
                math.ceil(W / self.vpm.patch_embed.patch_size[0]),
            )
            vision_embedding = self.vpm.forward_features(
                pixel_value.unsqueeze(0).type(dtype))
            if (hasattr(self.vpm, "num_prefix_tokens")
                    and self.vpm.num_prefix_tokens > 0):
                vision_embedding = vision_embedding[
                    :, self.vpm.num_prefix_tokens:]
            res.append(self.resampler(vision_embedding, tgt_size))
        return torch.vstack(res)

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]

        return self.get_vision_embedding(pixel_values)

    def is_default_weight_loading(self, name: str) -> bool:
        return "resampler" in name or "vpm" in name


class MiniCPMV2_5(MiniCPMVBaseModel):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config, multimodal_config, cache_config, quant_config)
        assert self.version == (2, 5)

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        return LlamaModel(config,
                          cache_config=cache_config,
                          quant_config=quant_config)

    def init_vision_module(self) -> nn.Module:
        model = Idefics2VisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
        return model

    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            resampler = Resampler2_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
            )
        return resampler

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        vision_embedding = self.vpm(pixel_values,
                                    patch_attention_mask=patch_attn_mask)
        vision_embedding = self.resampler(vision_embedding, tgt_sizes)
        return vision_embedding

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2,
                                                    1).reshape(B, 3, -1, L)

        patch_attn_mask = torch.zeros((B, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for i in range(B):
            # Mark only the valid (unpadded) patches for each image, matching
            # the 2.6 code path below.
            patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        return self.get_vision_embedding(all_pixel_values.type(dtype),
                                         patch_attn_mask, tgt_sizes)

    def is_default_weight_loading(self, name: str) -> bool:
        return "resampler" in name


class MiniCPMV2_6(MiniCPMVBaseModel):

    def __init__(
        self,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(config, multimodal_config, cache_config, quant_config)
        assert self.version == (2, 6)

    def init_llm(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> nn.Module:
        return Qwen2Model(config,
                          cache_config=cache_config,
                          quant_config=quant_config)

    def init_vision_module(self) -> nn.Module:
        # A custom version of SiglipVisionTransformer; it does not work with
        # tensor parallelism.
        from aphrodite.modeling.models.na_vit import SiglipVisionTransformer

        if self.config._attn_implementation == "flash_attention_2":
            self.config.vision_config._attn_implementation = \
                "flash_attention_2"
        else:
            # SDPA is not supported; fall back to eager attention.
            self.config.vision_config._attn_implementation = "eager"
        model = SiglipVisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
        return model

    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
        with set_default_torch_dtype(torch.float16):
            # The resampler in 2.6 remains consistent with the one in 2.5.
            resampler = Resampler2_5(
                num_queries=self.config.query_num,
                embed_dim=embed_dim,
                num_heads=embed_dim // 128,
                kv_dim=vision_dim,
            )
        return resampler

    def get_vision_embedding(
        self,
        pixel_values: List[torch.Tensor],
        patch_attn_mask: Optional[torch.Tensor] = None,
        tgt_sizes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        vision_embedding = self.vpm(
            pixel_values,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        ).last_hidden_state
        return vision_embedding

    def get_vision_hidden_states(self,
                                 data: MiniCPMVImageInputs) -> torch.Tensor:
        pixel_values = data["pixel_values"]
        tgt_sizes = data["tgt_sizes"]

        device = self.vpm.embeddings.position_embedding.weight.device
        dtype = self.vpm.embeddings.position_embedding.weight.dtype
        all_pixel_values_lst = [
            i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
        ]

        max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
        assert isinstance(max_patches, int)

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
            all_pixel_values_lst, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2,
                                                    1).reshape(B, 3, -1, L)

        patch_attn_mask = torch.zeros((B, 1, max_patches),
                                      dtype=torch.bool,
                                      device=device)
        for i in range(B):
            patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        vision_embedding = self.vpm(
            all_pixel_values.type(dtype),
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        ).last_hidden_state

        return self.resampler(vision_embedding, tgt_sizes)

    def is_default_weight_loading(self, name: str) -> bool:
        return "resampler" in name or "vpm" in name


_SUPPORT_VERSION = {
    (2, 0): MiniCPMV2_0,
    (2, 5): MiniCPMV2_5,
    (2, 6): MiniCPMV2_6,
}
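
# Version dispatch at a glance (derived from the classes above):
#
#   (2, 0) -> MiniCPMV2_0: MiniCPMModel LLM + timm SigLIP ViT + Resampler2
#   (2, 5) -> MiniCPMV2_5: LlamaModel LLM + Idefics2VisionTransformer
#                          + Resampler2_5
#   (2, 6) -> MiniCPMV2_6: Qwen2Model LLM + NaViT SiglipVisionTransformer
#                          + Resampler2_5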


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(MiniCPMVBaseModel):
    """
    Different versions of MiniCPM-V use different visual encoders and LLMs,
    which does not fit the current LoRA and bitsandbytes integration logic
    in vLLM. The versions are therefore kept as separate classes, and this
    wrapper dispatches to the right one based on the config.
    """

    def __new__(
        cls,
        config: PretrainedConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        if not hasattr(config, "version"):
            if config.hidden_size == 2304 and config.query_num == 64:
                version = (2, 0)
            else:
                version = (2, 5)
        else:
            version = str(config.version).split(".")
            version = tuple([int(x) for x in version])
        # Dispatch class based on version
        instance_class = _SUPPORT_VERSION.get(version, None)
        if instance_class is None:
            raise ValueError(
                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
        return instance_class(config, multimodal_config, cache_config,
                              quant_config)