minicpmv.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
  24. import math
  25. import re
  26. from functools import partial
  27. from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
  28. TypedDict, Union)
  29. import numpy as np
  30. import torch
  31. import torch.nn.functional as F
  32. import torch.types
  33. from PIL import Image
  34. from torch import nn
  35. from torch.nn.init import trunc_normal_
  36. from transformers.configuration_utils import PretrainedConfig
  37. from aphrodite.attention import AttentionMetadata
  38. from aphrodite.common.config import CacheConfig, MultiModalConfig
  39. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  40. SequenceData)
  41. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  42. from aphrodite.modeling.layers.linear import ReplicatedLinear
  43. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  46. from aphrodite.modeling.model_loader.utils import set_default_torch_dtype
  47. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  48. from aphrodite.modeling.models.interfaces import SupportsMultiModal
  49. from aphrodite.modeling.models.llama import LlamaModel
  50. from aphrodite.modeling.models.minicpm import MiniCPMModel
  51. from aphrodite.modeling.models.qwen2 import Qwen2Model
  52. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  53. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  54. from aphrodite.multimodal.image import (cached_get_image_processor,
  55. cached_get_tokenizer)
  56. from aphrodite.quantization.base_config import QuantizationConfig
  57. from .idefics2_vision_model import Idefics2VisionTransformer
  # Prefix rewrites applied to checkpoint weight names in ``load_weights``
  # before parameter lookup. The checkpoint nests the language model under
  # ``llm.``; this module keeps its decoder at ``self.llm`` and its output
  # head at ``self.lm_head``. Entries are applied in insertion order.
  _KEYS_TO_MODIFY_MAPPING = {
      "llm.lm_head": "lm_head",
      "llm.model": "llm",
  }
  class MiniCPMVImagePixelInputs(TypedDict):
      """Pixel-level image inputs for MiniCPM-V, flattened across the batch."""

      # Shape: `(batch_size * num_images, num_channels, height, width)`.
      # Image sizes may vary, so this is a list instead of a batched tensor.
      pixel_values: List[torch.Tensor]

      # Shape: `(batch_size * num_images, 2)`, in `(start, stop)` format.
      image_bounds: torch.Tensor

      # Shape: `(batch_size * num_images, 2)`, in `(height, width)` format.
      tgt_sizes: torch.Tensor


  # Pixel inputs are the only supported image input type, so the generic
  # alias simply names that TypedDict.
  MiniCPMVImageInputs = MiniCPMVImagePixelInputs

  # LayerNorm factory (eps=1e-6) used as the default ``norm_layer`` of the
  # resamplers.
  DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
  def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
      """Resize a square table of absolute position embeddings.

      Args:
          abs_pos: ``(L, C)`` embeddings laid out on a ``sqrt(L) x sqrt(L)``
              grid.
          tgt_size: target grid size as ``(H, W)``.

      Returns:
          ``(H * W, C)`` embeddings, bicubically interpolated in float32 and
          cast back to the dtype of ``abs_pos``.
      """
      side = int(math.sqrt(abs_pos.size(0)))
      orig_dtype = abs_pos.dtype
      # (L, C) -> (1, side, side, C) -> (1, C, side, side), the NCHW layout
      # that F.interpolate expects.
      grid = abs_pos.float().reshape(1, side, side, -1).permute(0, 3, 1, 2)
      resized = F.interpolate(
          grid,
          size=(tgt_size[0], tgt_size[1]),
          mode="bicubic",
          align_corners=False,
      )
      # (1, C, H, W) -> (1, H, W, C) -> (H * W, C)
      flat = resized.permute(0, 2, 3, 1).flatten(0, 2)
      return flat.to(dtype=orig_dtype)
  # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
  def get_2d_sincos_pos_embed(
      embed_dim: int,
      grid_size: Union[int, Tuple[int, int]],
      cls_token: bool = False,
      version: Tuple[int, int] = (2, 0),
  ):
      """Build a 2D sine-cosine position-embedding table.

      Args:
          embed_dim: embedding width (must be even; asserted by the grid
              helper).
          grid_size: grid side length as an int, or ``(height, width)``.
          cls_token: version ``(2, 0)`` only; prepend one all-zero row.
              Ignored for other versions.
          version: ``(2, 0)`` yields a flat table; any other version keeps
              the grid layout.

      Returns:
          Version ``(2, 0)``: ``[H * W, embed_dim]``, or
          ``[1 + H * W, embed_dim]`` with ``cls_token``.
          Other versions: ``[H, W, embed_dim]``.
      """
      if isinstance(grid_size, int):
          height = width = grid_size
      else:
          height, width = grid_size[0], grid_size[1]

      rows = np.arange(height, dtype=np.float32)
      cols = np.arange(width, dtype=np.float32)
      # The width axis is passed to meshgrid first, so coords[0] holds the
      # column coordinate and coords[1] the row coordinate of every cell.
      coords = np.stack(np.meshgrid(cols, rows), axis=0)

      if version != (2, 0):
          return get_2d_sincos_pos_embed_from_grid(embed_dim, coords, version)

      coords = coords.reshape([2, 1, height, width])
      table = get_2d_sincos_pos_embed_from_grid(embed_dim, coords, version)
      if cls_token:
          table = np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
      return table
  def get_2d_sincos_pos_embed_from_grid(embed_dim: int,
                                        grid: np.ndarray,
                                        version: Tuple[int, int] = (2, 0)):
      """Encode a stacked 2-axis coordinate grid with sine-cosine features.

      Half of ``embed_dim`` encodes ``grid[0]`` and the other half encodes
      ``grid[1]``. The two halves are joined along the last axis.

      Returns:
          ``(H * W, embed_dim)`` for version ``(2, 0)``, otherwise
          ``(H, W, embed_dim)``.
      """
      assert embed_dim % 2 == 0
      half = embed_dim // 2
      first_half = get_1d_sincos_pos_embed_from_grid(half, grid[0], version)
      second_half = get_1d_sincos_pos_embed_from_grid(half, grid[1], version)
      # Version (2, 0) halves are 2-D (M, D/2); later versions are 3-D
      # (H, W, D/2), so join on the trailing axis.
      join_axis = 1 if version == (2, 0) else -1
      return np.concatenate([first_half, second_half], axis=join_axis)
  def get_1d_sincos_pos_embed_from_grid(embed_dim: int,
                                        pos: np.ndarray,
                                        version: Tuple[int, int] = (2, 0)):
      """Encode positions with sine/cosine features.

      Args:
          embed_dim: output width per position; must be even.
          pos: positions to encode. Version ``(2, 0)`` flattens it to
              ``(M,)``; other versions expect an ``(H, W)`` array.
          version: selects the flattened or the grid-shaped layout.

      Returns:
          ``(M, embed_dim)`` for version ``(2, 0)``, else
          ``(H, W, embed_dim)``. The first half of the last axis holds the
          sine features, the second half the cosine features.
      """
      assert embed_dim % 2 == 0
      # Frequencies 1 / 10000 ** (2i / embed_dim) for i in [0, embed_dim / 2).
      freqs = np.arange(embed_dim // 2, dtype=np.float32)
      freqs /= embed_dim / 2.0
      freqs = 1.0 / 10000**freqs

      if version == (2, 0):
          # Outer product over the flattened positions: (M, D/2).
          phases = np.einsum("m,d->md", pos.reshape(-1), freqs)
          join_axis = 1
      else:
          # Outer product that keeps the grid layout: (H, W, D/2).
          phases = np.einsum("hw,d->hwd", pos, freqs)
          join_axis = -1
      return np.concatenate([np.sin(phases), np.cos(phases)], axis=join_axis)
  class BaseResampler(nn.Module):
      """Perceiver-style 2D resampler shared by the MiniCPM-V resamplers.

      ``num_queries`` learnable queries cross-attend, through a single
      ``nn.MultiheadAttention`` layer, over projected visual features.
      Subclasses add the 2D sine-cosine position embeddings and implement
      ``forward``, which yields ``num_queries`` vectors of width ``embed_dim``
      per input.
      """

      def __init__(
          self,
          num_queries: int,
          embed_dim: int,
          num_heads: int,
          kv_dim: Optional[int] = None,
          norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
      ) -> None:
          super().__init__()
          self.num_queries = num_queries
          self.embed_dim = embed_dim
          self.num_heads = num_heads

          self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
          trunc_normal_(self.query, std=0.02)

          needs_projection = kv_dim is not None and kv_dim != embed_dim
          if needs_projection:
              self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
          else:
              # Callers unpack ``x, _ = self.kv_proj(x)``, so mirror the
              # 2-tuple returned by ReplicatedLinear.forward.
              def _passthrough(*args, **kwargs):
                  return nn.Identity()(*args, **kwargs), None

              self.kv_proj = _passthrough

          self.attn = nn.MultiheadAttention(embed_dim, num_heads)
          self.ln_q = norm_layer(embed_dim)
          self.ln_kv = norm_layer(embed_dim)
          self.ln_post = norm_layer(embed_dim)
          self.proj = nn.Parameter(
              (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

      def _init_weights(self, m: nn.Module) -> None:
          """Per-module initializer, applied by subclasses via ``self.apply``."""
          if isinstance(m, nn.Linear):
              trunc_normal_(m.weight, std=0.02)
              if m.bias is not None:
                  nn.init.constant_(m.bias, 0)
          elif isinstance(m, nn.LayerNorm):
              nn.init.constant_(m.bias, 0)
              nn.init.constant_(m.weight, 1.0)

      def _repeat(self, query, N: int):
          """Tile ``(Q, D)`` queries to ``(Q, N, D)``, one copy per batch item."""
          with_batch_axis = query.unsqueeze(1)
          return with_batch_axis.repeat(1, N, 1)
  class Resampler2(BaseResampler):
      """Resampler used by MiniCPM-V 2.0.

      The ``grid_size**2`` queries carry a fixed 2D sine-cosine position
      embedding. Key position embeddings are either recomputed for the
      target grid (``adaptive=True``) or interpolated from the stored table.
      """

      def __init__(
          self,
          grid_size: int,
          embed_dim: int,
          num_heads: int,
          kv_dim: Optional[int] = None,
          norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
          adaptive: bool = False,
      ) -> None:
          super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
                           norm_layer)
          self.adaptive = adaptive
          table = get_2d_sincos_pos_embed(embed_dim, grid_size, version=(2, 0))
          # A frozen Parameter: it is registered (and so appears in the
          # state dict) but is never trained.
          self.pos_embed = nn.Parameter(
              torch.from_numpy(table).float()).requires_grad_(False)
          self.apply(self._init_weights)

      def forward(
          self,
          x: torch.Tensor,
          tgt_sizes: torch.Tensor,
          attn_mask: Optional[torch.Tensor] = None,
      ):
          """Resample visual features into ``grid_size**2`` query outputs.

          Args:
              x: batch-first visual features (presumably ``(B, L, kv_dim)``;
                  permuted to sequence-first for the attention layer).
              tgt_sizes: target ``(H, W)`` patch grid for the key position
                  embeddings. MiniCPMV2 passes a tuple of ints.
              attn_mask: optional mask forwarded to ``nn.MultiheadAttention``.
          """
          if self.adaptive:
              table = get_2d_sincos_pos_embed(self.embed_dim,
                                              tgt_sizes,
                                              version=(2, 0))
              key_pos = torch.from_numpy(table).to(device=x.device,
                                                   dtype=x.dtype)
          else:
              key_pos = get_abs_pos(self.pos_embed, tgt_sizes)

          x, _ = self.kv_proj(x)
          kv = self.ln_kv(x).permute(1, 0, 2)
          batch = kv.shape[1]

          queries = self._repeat(self.ln_q(self.query), batch)
          attended = self.attn(
              queries + self.pos_embed.unsqueeze(1),
              kv + key_pos.unsqueeze(1),
              kv,
              attn_mask=attn_mask,
          )[0]
          normed = self.ln_post(attended.permute(1, 0, 2))
          return normed @ self.proj
  class Resampler2_5(BaseResampler):
      """Resampler used by MiniCPM-V 2.5 for variable-size patch grids.

      Key position embeddings are sliced from a cached
      ``(max_h, max_w, embed_dim)`` sine-cosine table, which is rebuilt
      larger whenever a request needs a bigger grid.
      """

      def __init__(
          self,
          num_queries: int,
          embed_dim: int,
          num_heads: int,
          kv_dim: Optional[int] = None,
          norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
          max_size: Tuple[int, int] = (70, 70),
      ) -> None:
          super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
          self.max_size = max_size
          self._set_2d_pos_cache(self.max_size)
          self.apply(self._init_weights)

      def _set_2d_pos_cache(self,
                            max_size: Tuple[int, int],
                            device: torch.types.Device = "cpu") -> None:
          """(Re)build the ``pos_embed`` buffer for a ``max_size`` grid."""
          table = get_2d_sincos_pos_embed(self.embed_dim,
                                          max_size,
                                          version=(2, 5))
          cache = torch.from_numpy(table).float().to(device)
          # Non-persistent: excluded from the state dict because it is
          # recomputed from embed_dim and max_size.
          self.register_buffer("pos_embed", cache, persistent=False)

      def _adjust_pos_cache(self, tgt_sizes: torch.Tensor,
                            device: torch.types.Device) -> None:
          """Grow the cached table if any target grid exceeds it."""
          need_h = tgt_sizes[:, 0].max().item()
          need_w = tgt_sizes[:, 1].max().item()
          assert isinstance(need_h, int) and isinstance(need_w, int)
          cur_h, cur_w = self.max_size
          if need_h <= cur_h and need_w <= cur_w:
              return
          self.max_size = (max(need_h, cur_h), max(need_w, cur_w))
          self._set_2d_pos_cache(self.max_size, device)

      def forward(self, x: torch.Tensor,
                  tgt_sizes: torch.Tensor) -> torch.Tensor:
          """Resample padded patch features into ``num_queries`` outputs.

          Args:
              x: ``(B, L, D)`` patch features, right-padded along ``L``.
              tgt_sizes: ``(B, 2)`` per-sample patch grid ``(height, width)``.

          Returns:
              ``(B, num_queries, embed_dim)``.
          """
          assert x.shape[0] == tgt_sizes.shape[0]
          bs = x.shape[0]
          device, dtype = x.device, x.dtype

          patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
          self._adjust_pos_cache(tgt_sizes, device=device)

          max_patch_len = patch_len.max().item()
          assert isinstance(max_patch_len, int)

          # True marks padded key positions that attention must ignore.
          key_padding_mask = torch.zeros((bs, max_patch_len),
                                         dtype=torch.bool,
                                         device=device)
          per_sample_pos = []
          for idx in range(bs):
              h, w = tgt_sizes[idx].tolist()
              per_sample_pos.append(
                  self.pos_embed[:h, :w, :].reshape((h * w, -1)).to(dtype))
              key_padding_mask[idx, patch_len[idx]:] = True

          # (B, L, D) -> (L, B, D) to match the sequence-first attention.
          key_pos = torch.nn.utils.rnn.pad_sequence(
              per_sample_pos, batch_first=True,
              padding_value=0.0).permute(1, 0, 2)

          x, _ = self.kv_proj(x)
          kv = self.ln_kv(x).permute(1, 0, 2)  # (L, B, D)
          queries = self._repeat(self.ln_q(self.query), bs)  # (Q, B, D)
          attended = self.attn(
              queries,
              kv + key_pos,
              kv,
              key_padding_mask=key_padding_mask,
          )[0]  # (Q, B, D)
          normed = self.ln_post(attended.permute(1, 0, 2))  # (B, Q, D)
          return normed @ self.proj
  def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
      """Return the MiniCPM-V version as a tuple of ints, e.g. ``(2, 5)``.

      ``config.version`` is used when present (``2.6`` -> ``(2, 6)``).
      Older configs lack that field, so the version is inferred instead:
      ``hidden_size == 2304`` with ``query_num == 64`` means 2.0, and
      anything else is treated as 2.5.
      """
      version = getattr(config, "version", None)
      if version is not None:
          return tuple(int(part) for part in str(version).split("."))
      # The old configs do not include a version number.
      # TODO: Remove this after the HF repos are updated
      is_v2_0 = config.hidden_size == 2304 and config.query_num == 64
      return (2, 0) if is_v2_0 else (2, 5)
  def get_max_minicpmv_image_tokens(ctx: InputContext):
      """Return ``query_num`` from the HF config, or 64 if it is absent."""
      config = ctx.get_hf_config(PretrainedConfig)
      return getattr(config, "query_num", 64)
  def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
      """Wrap ``seq_len`` zero token ids in ``SequenceData``.

      ``num_images`` is accepted but currently unused.
      """
      zero_ids = [0 for _ in range(seq_len)]
      return SequenceData(zero_ids)
  def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
      """Build black RGB dummy images with side ``hf_config.image_size``.

      Returns ``{"image": img}`` when ``num_images == 1``. Otherwise returns
      ``{"image": [img] * num_images}``, repeating the same image object.
      """
      side = hf_config.image_size
      image = Image.new("RGB", (side, side), color=0)
      if num_images == 1:
          return {"image": image}
      return {"image": [image] * num_images}
  def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
      """Return ``(seq_data, mm_data)`` dummies for ``mm_counts["image"]`` images."""
      hf_config = ctx.get_hf_config()
      image_count = mm_counts["image"]
      return (
          dummy_seq_data_for_minicpmv(seq_len, image_count),
          dummy_image_for_minicpmv(hf_config, image_count),
      )
  def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
      """Expand MiniCPM-V image tags in the prompt into image placeholders.

      Each literal ``(<image>./</image>)`` tag is replaced by the image
      processor's slice placeholder for the corresponding image, and the
      resulting prompt is re-tokenized. Inputs without image data are
      returned unchanged. A prompt without any tag keeps its original prompt
      text and token ids.

      Args:
          ctx: input context providing the model config.
          llm_inputs: inputs with ``prompt`` and/or ``prompt_token_ids`` and
              ``multi_modal_data``.

      Returns:
          ``LLMInputs`` with the (possibly) expanded prompt and token ids.

      Raises:
          ValueError: if the prompt has image tags but their number differs
              from the number of images supplied.
      """
      multi_modal_data = llm_inputs.get("multi_modal_data")
      if multi_modal_data is None or "image" not in multi_modal_data:
          return llm_inputs

      model_config = ctx.model_config
      version = get_version_by_config(model_config.hf_config)
      tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                       trust_remote_code=True)
      image_processor = cached_get_image_processor(model_config.tokenizer)

      def get_placeholder(image_size: Tuple[int, int], num_image: int):
          # The 2.0 / 2.5 processors take only the image size; later versions
          # also take the image index.
          if version in ((2, 0), (2, 5)):
              return image_processor.get_slice_image_placeholder(image_size)
          return image_processor.get_slice_image_placeholder(
              image_size, num_image)

      # Read the token ids unconditionally: the no-tag branch below returns
      # them even when a text prompt was supplied. Previously they were only
      # bound when ``prompt`` was None, causing an UnboundLocalError.
      token_ids = llm_inputs.get("prompt_token_ids")
      prompt = llm_inputs.get("prompt")
      if prompt is None:
          prompt = tokenizer.decode(token_ids)

      pattern = "(<image>./</image>)"
      images = multi_modal_data["image"]
      if isinstance(images, Image.Image):
          images = [images]

      # Match the tag literally so this count agrees with ``str.split``
      # below; unescaped, "(", ")" and "." are regex metacharacters.
      image_tags = re.findall(re.escape(pattern), prompt)
      if len(image_tags) == 0:
          new_token_ids = token_ids
          new_prompt = prompt
      else:
          # With a mismatch, the splice below would silently drop or
          # duplicate prompt text (or index past the chunks).
          if len(image_tags) != len(images):
              raise ValueError(
                  f"The prompt contains {len(image_tags)} image tag(s) "
                  f"but {len(images)} image(s) were provided.")
          text_chunks = prompt.split(pattern)
          new_prompt_chunks: List[str] = []
          for i, image in enumerate(images):
              new_prompt_chunks += [
                  text_chunks[i],
                  get_placeholder(image.size, i),
              ]
          new_prompt_chunks.append(text_chunks[-1])
          new_prompt = "".join(new_prompt_chunks)
          new_token_ids = tokenizer.encode(new_prompt)

      llm_inputs = LLMInputs(
          prompt_token_ids=new_token_ids,
          prompt=new_prompt,
          multi_modal_data=multi_modal_data,
      )
      return llm_inputs
  class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
      """Shared MiniCPM-V implementation; meant to be subclassed, not built.

      Subclasses must provide the following hooks, whose base versions all
      raise ``NotImplementedError``:
      - the language model, vision tower and resampler factories
        (``init_llm`` / ``init_vision_module`` / ``init_resampler``);
      - the vision feature hooks;
      - ``is_default_weight_loading``.
      """

      def __init__(
          self,
          config: PretrainedConfig,
          multimodal_config: MultiModalConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ):
          super().__init__()
          self.config = config
          self.multimodal_config = multimodal_config
          self.version = get_version_by_config(self.config)

          self.llm = self.init_llm(config, cache_config, quant_config)
          self.vpm = self.init_vision_module()

          # Cast the vision tower (and, below, the resampler) to the default
          # dtype that is active while this model is constructed.
          param_dtype = torch.get_default_dtype()
          self.vpm.to(dtype=param_dtype)

          if self.version == (2, 0):
              # The 2.0 vision tower exposes ``embed_dim`` directly.
              self.vision_dim = self.vpm.embed_dim
          else:
              self.vision_dim = self.vpm.embeddings.embed_dim
          self.embed_dim = self.config.hidden_size

          self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
          # NOTE(review): the resampler device is hard-coded to CUDA.
          self.resampler.to(device="cuda", dtype=param_dtype)

          self.lm_head = ParallelLMHead(config.vocab_size,
                                        config.hidden_size,
                                        quant_config=quant_config)
          self.logits_processor = LogitsProcessor(config.vocab_size)
          self.sampler = Sampler()

      def get_embedding(
          self,
          input_ids: torch.Tensor,
          image_inputs: Optional[MiniCPMVImageInputs],
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Embed ``input_ids`` and overwrite image spans with vision features.

          Returns ``(embeddings, vision_hidden_states)``. The second element
          is an empty tensor when ``image_inputs`` is None.
          """
          vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
          if hasattr(self.config, "scale_emb"):
              vlm_embedding *= self.config.scale_emb

          if image_inputs is None:
              return vlm_embedding, torch.tensor([], device=input_ids.device)

          vision_hidden_states = self.get_vision_hidden_states(image_inputs)
          # Bounds may be empty; see the NOTE in _parse_and_validate_inputs.
          image_bounds = image_inputs["image_bounds"]
          if len(image_bounds) > 0:
              target_rows = torch.stack([
                  torch.arange(start, end, dtype=torch.long)
                  for start, end in image_bounds.tolist()
              ]).to(vlm_embedding.device)
              hidden_dim = vlm_embedding.shape[-1]
              # Each listed token position receives one row of the
              # flattened vision features.
              vlm_embedding.scatter_(
                  0,
                  target_rows.view(-1, 1).repeat(1, hidden_dim),
                  vision_hidden_states.view(-1,
                                            vision_hidden_states.shape[-1]),
              )
          return vlm_embedding, vision_hidden_states

      def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
          """Locate image/slice spans in ``input_ids`` as ``(start, end)`` rows.

          ``start`` is the position right after an opening marker token and
          ``end`` the position of a closing marker. ``get_embedding`` fills
          ``range(start, end)`` with vision features.
          """
          tokenizer = cached_get_tokenizer(self.config._name_or_path,
                                           trust_remote_code=True)
          opens = input_ids == tokenizer.im_start_id
          closes = input_ids == tokenizer.im_end_id
          if hasattr(tokenizer, "slice_start_id"):
              opens |= (input_ids == tokenizer.slice_start_id)
              closes |= (input_ids == tokenizer.slice_end_id)

          starts = torch.where(opens)[0] + 1
          ends = torch.where(closes)[0]
          # NOTE(review): if the marker counts differ, slicing both to max()
          # leaves tensors of different lengths and hstack raises.
          count = max(len(starts), len(ends))
          if count == 0:
              return torch.zeros((0, 2), device=input_ids.device)
          return torch.hstack([
              starts[:count].unsqueeze(-1),
              ends[:count].unsqueeze(-1),
          ])

      def _parse_and_validate_inputs(
          self,
          input_ids: torch.Tensor,
          **kwargs: object,
      ) -> Optional[MiniCPMVImageInputs]:
          """Flatten per-request ``pixel_values`` / ``tgt_sizes`` from kwargs.

          Returns None when no images were supplied.

          Raises:
              ValueError: on unexpected input types or inconsistent lengths.
          """
          pixel_values = kwargs.pop("pixel_values", [])
          tgt_sizes = kwargs.pop("tgt_sizes", [])

          if not isinstance(pixel_values, (torch.Tensor, list)):
              raise ValueError("Incorrect type of pixel values. "
                               f"Got type: {type(pixel_values)}")
          if not isinstance(tgt_sizes, (torch.Tensor, list)):
              raise ValueError("Incorrect type of target sizes. "
                               f"Got type: {type(tgt_sizes)}")
          if len(pixel_values) != len(tgt_sizes):
              raise ValueError("Inconsistent batch lengths, found: "
                               f"{len(pixel_values)} vs. {len(tgt_sizes)}")

          pixel_values_flat: List[torch.Tensor] = []
          tgt_sizes_flat: List[torch.Tensor] = []
          for request_pixels, request_sizes in zip(pixel_values, tgt_sizes):
              pixel_values_flat.extend(request_pixels)
              tgt_sizes_flat.extend(request_sizes)

          # NOTE: Input IDs does not contain image tokens during memory
          # profiling, so the image bounds computed below may be empty.
          if len(pixel_values_flat) != len(tgt_sizes_flat):
              raise ValueError("Inconsistent flattened lengths, found: "
                               f"{len(pixel_values_flat)} vs. "
                               f"{len(tgt_sizes_flat)}")
          if not pixel_values_flat:
              return None

          return MiniCPMVImageInputs(
              image_bounds=self._get_image_bounds(input_ids),
              pixel_values=pixel_values_flat,
              tgt_sizes=torch.stack(tgt_sizes_flat),
          )

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          intermediate_tensors: Optional[IntermediateTensors] = None,
          **kwargs: Any,
      ) -> torch.Tensor:
          """Run the LLM on token embeddings merged with image features."""
          image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
          inputs_embeds, _ = self.get_embedding(input_ids, image_inputs)
          return self.llm(
              input_ids=None,
              positions=positions,
              kv_caches=kv_caches,
              attn_metadata=attn_metadata,
              intermediate_tensors=intermediate_tensors,
              inputs_embeds=inputs_embeds,
          )

      def compute_logits(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[torch.Tensor]:
          """Project hidden states to vocabulary logits through ``lm_head``."""
          return self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Sample next tokens from ``logits``."""
          return self.sampler(logits, sampling_metadata)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Load checkpoint ``(name, tensor)`` pairs into this module.

          Names are first rewritten with ``_KEYS_TO_MODIFY_MAPPING`` and
          rotary caches are skipped. Weights flagged by
          ``is_default_weight_loading`` are loaded as-is. Other weights are
          routed to the first matching fused ``qkv_proj`` / ``gate_up_proj``
          shard, or loaded as-is when no shard name matches.
          """
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
          # inv_freq has no matching parameter; cos/sin caches can appear in
          # checkpoints trained with ColossalAI. All of them are skipped.
          skipped_markers = (
              "rotary_emb.inv_freq",
              "rotary_emb.cos_cached",
              "rotary_emb.sin_cached",
          )
          params_dict = dict(self.named_parameters())

          for name, loaded_weight in weights:
              for old_key, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                  if old_key in name:
                      name = name.replace(old_key, new_key)
              if any(marker in name for marker in skipped_markers):
                  continue

              if not self.is_default_weight_loading(name):
                  shard = next(
                      ((param_name, weight_name, shard_id)
                       for param_name, weight_name, shard_id in
                       stacked_params_mapping if weight_name in name),
                      None,
                  )
                  if shard is not None:
                      param_name, weight_name, shard_id = shard
                      param = params_dict[name.replace(weight_name,
                                                       param_name)]
                      param.weight_loader(param, loaded_weight, shard_id)
                      continue

              param = params_dict[name]
              weight_loader = getattr(param, "weight_loader",
                                      default_weight_loader)
              weight_loader(param, loaded_weight)

      def init_llm(
          self,
          config: PretrainedConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> nn.Module:
          """Build the language-model backbone (subclass hook)."""
          raise NotImplementedError

      def init_vision_module(self) -> nn.Module:
          """Build the vision tower (subclass hook)."""
          raise NotImplementedError

      def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
          """Build the vision-to-LLM resampler (subclass hook)."""
          raise NotImplementedError

      def get_vision_embedding(
          self,
          pixel_values: List[torch.Tensor],
          patch_attn_mask: Optional[torch.Tensor] = None,
          tgt_sizes: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Encode and resample pixel values (subclass hook)."""
          raise NotImplementedError

      def get_vision_hidden_states(self,
                                   data: MiniCPMVImageInputs) -> torch.Tensor:
          """Turn parsed image inputs into vision features (subclass hook)."""
          raise NotImplementedError

      def is_default_weight_loading(self, name: str) -> bool:
          """Whether ``name`` bypasses the stacked-shard mapping (subclass hook)."""
          raise NotImplementedError
  class MiniCPMV2(MiniCPMVBaseModel):
      """MiniCPM-V 2.0: MiniCPM LLM, timm SigLIP ViT and ``Resampler2``."""

      def __init__(
          self,
          config: PretrainedConfig,
          multimodal_config: MultiModalConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ):
          super().__init__(config, multimodal_config, cache_config, quant_config)
          assert self.version == (2, 0)

      def init_llm(
          self,
          config: PretrainedConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> nn.Module:
          """Build the MiniCPM language model."""
          return MiniCPMModel(config,
                              cache_config=cache_config,
                              quant_config=quant_config)

      def init_vision_module(self) -> nn.Module:
          """Build the timm SigLIP vision tower in float16.

          The attention-pooling head, if any, is replaced by an identity.
          The last block is dropped when ``config.drop_vision_last_layer``
          is set.

          Raises:
              ImportError: if ``timm`` is not installed.
          """
          # TODO :refactor this vision model
          try:
              import timm
          except ImportError as err:
              # Chain the caught exception (not the ImportError class) so the
              # original failure stays visible in the traceback.
              raise ImportError("Please install timm==0.9.10") from err

          with set_default_torch_dtype(torch.float16):
              model = timm.create_model(
                  "vit_so400m_patch14_siglip_384.webli",
                  pretrained=False,
                  num_classes=0,
                  dynamic_img_size=True,
                  dynamic_img_pad=True,
              )

          if (isinstance(model, timm.models.VisionTransformer)
                  and model.attn_pool is not None):
              model.attn_pool = torch.nn.Identity()

          if self.config.drop_vision_last_layer:
              model.blocks = model.blocks[:-1]

          return model

      def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
          """Build an adaptive ``Resampler2`` with ``sqrt(query_num)`` grid."""
          with set_default_torch_dtype(torch.float16):
              resampler = Resampler2(
                  embed_dim=embed_dim,
                  num_heads=embed_dim // 128,
                  grid_size=int(math.sqrt(self.config.query_num)),
                  kv_dim=vision_dim,
                  adaptive=True,
              )

          return resampler

      def get_vision_embedding(
          self,
          pixel_values: List[torch.Tensor],
          patch_attn_mask: Optional[torch.Tensor] = None,
          tgt_sizes: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Encode each image separately and resample its patch features.

          ``patch_attn_mask`` and ``tgt_sizes`` are unused here. The target
          grid of each image is derived from its height/width and the patch
          size (each ``pixel_value`` is assumed to be ``(C, H, W)``).

          Returns:
              The resampled features of all images, stacked along dim 0.
          """
          results = []
          dtype = self.vpm.pos_embed.data.dtype
          patch_size = self.vpm.patch_embed.patch_size
          # Use the matching axis of the patch size for each grid dimension
          # (identical for square patches).
          patch_h, patch_w = patch_size[0], patch_size[1]
          for pixel_value in pixel_values:
              height, width = pixel_value[0].shape[-2:]
              tgt_size = (
                  math.ceil(height / patch_h),
                  math.ceil(width / patch_w),
              )
              vision_embedding = self.vpm.forward_features(
                  pixel_value.unsqueeze(0).type(dtype))
              # Drop prefix (e.g. class) tokens so only patch tokens remain.
              if (hasattr(self.vpm, "num_prefix_tokens")
                      and self.vpm.num_prefix_tokens > 0):
                  vision_embedding = vision_embedding[:, self.vpm.
                                                      num_prefix_tokens:]
              results.append(self.resampler(vision_embedding, tgt_size))

          return torch.vstack(results)

      def get_vision_hidden_states(self,
                                   data: MiniCPMVImageInputs) -> torch.Tensor:
          """Encode ``data["pixel_values"]``; ``tgt_sizes`` is not needed."""
          pixel_values = data["pixel_values"]

          return self.get_vision_embedding(pixel_values)

      def is_default_weight_loading(self, name: str) -> bool:
          """Resampler and vision-tower weights bypass the shard mapping."""
          return "resampler" in name or "vpm" in name
  class MiniCPMV2_5(MiniCPMVBaseModel):
      """MiniCPM-V 2.5: Llama LLM, Idefics2 vision encoder, ``Resampler2_5``."""

      def __init__(
          self,
          config: PretrainedConfig,
          multimodal_config: MultiModalConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ):
          super().__init__(config, multimodal_config, cache_config, quant_config)
          assert self.version == (2, 5)

      def init_llm(
          self,
          config: PretrainedConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> nn.Module:
          """Build the Llama language model."""
          return LlamaModel(config,
                            cache_config=cache_config,
                            quant_config=quant_config)

      def init_vision_module(self) -> nn.Module:
          """Build the Idefics2 vision encoder.

          The last encoder layer is dropped when
          ``config.drop_vision_last_layer`` is set.
          """
          model = Idefics2VisionTransformer(self.config.vision_config)
          if self.config.drop_vision_last_layer:
              model.encoder.layers = model.encoder.layers[:-1]
          return model

      def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
          """Build a ``Resampler2_5`` with ``config.query_num`` queries."""
          with set_default_torch_dtype(torch.float16):
              resampler = Resampler2_5(
                  num_queries=self.config.query_num,
                  embed_dim=embed_dim,
                  num_heads=embed_dim // 128,
                  kv_dim=vision_dim,
              )
          return resampler

      def get_vision_embedding(
          self,
          pixel_values: List[torch.Tensor],
          patch_attn_mask: Optional[torch.Tensor] = None,
          tgt_sizes: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Run the vision encoder with ``patch_attn_mask``, then resample.

          NOTE(review): ``get_vision_hidden_states`` passes a single padded
          tensor here despite the ``List`` annotation.
          """
          vision_embedding = self.vpm(pixel_values,
                                      patch_attention_mask=patch_attn_mask)
          vision_embedding = self.resampler(vision_embedding, tgt_sizes)
          return vision_embedding

      def get_vision_hidden_states(self,
                                   data: MiniCPMVImageInputs) -> torch.Tensor:
          """Pad per-slice pixels into one batch, then encode and resample.

          Each ``pixel_values`` entry is flattened over its first two dims and
          transposed, so the entries can be right-padded with
          ``pad_sequence``. The batch is then reshaped to ``(B, 3, -1, L)``.
          A boolean ``(B, 1, max_patches)`` mask marks the ``h * w`` real
          patches of each slice.
          """
          pixel_values = data["pixel_values"]
          tgt_sizes = data["tgt_sizes"]

          device = self.vpm.embeddings.position_embedding.weight.device
          dtype = self.vpm.embeddings.position_embedding.weight.dtype
          all_pixel_values_lst = [
              i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
          ]

          max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
          assert isinstance(max_patches, int)

          all_pixel_values = torch.nn.utils.rnn.pad_sequence(
              all_pixel_values_lst, batch_first=True, padding_value=0.0)
          B, L, _ = all_pixel_values.shape
          all_pixel_values = all_pixel_values.permute(0, 2,
                                                      1).reshape(B, 3, -1, L)

          patch_attn_mask = torch.zeros((B, 1, max_patches),
                                        dtype=torch.bool,
                                        device=device)
          for i in range(B):
              num_patches = tgt_sizes[i][0] * tgt_sizes[i][1]
              # Index the singleton dim explicitly. ``[i, :num_patches]``
              # would slice that size-1 dim instead and mark every patch,
              # padding included, as valid.
              patch_attn_mask[i, 0, :num_patches] = True

          return self.get_vision_embedding(all_pixel_values.type(dtype),
                                           patch_attn_mask, tgt_sizes)

      def is_default_weight_loading(self, name: str) -> bool:
          """Resampler weights bypass the stacked-shard mapping."""
          return "resampler" in name
# NOTE: Currently, information about this model is unavailable. We are
# temporarily using `MiniCPMVQwen2` as its name. The name may need
# to be modified in the future.
  750. class MiniCPMVQwen2(MiniCPMVBaseModel):
  751. def __init__(
  752. self,
  753. config: PretrainedConfig,
  754. multimodal_config: MultiModalConfig,
  755. cache_config: Optional[CacheConfig] = None,
  756. quant_config: Optional[QuantizationConfig] = None,
  757. ):
  758. super().__init__(config, multimodal_config, cache_config, quant_config)
  759. def init_llm(
  760. self,
  761. config: PretrainedConfig,
  762. cache_config: Optional[CacheConfig] = None,
  763. quant_config: Optional[QuantizationConfig] = None,
  764. ) -> nn.Module:
  765. return Qwen2Model(config,
  766. cache_config=cache_config,
  767. quant_config=quant_config)
  768. def init_vision_module(self) -> nn.Module:
  769. # A custom version of SiglipVisionTransformer, won't work with TP
  770. from aphrodite.modeling.models.na_vit import SiglipVisionTransformer
  771. if self.config._attn_implementation == "flash_attention_2":
  772. self.config.vision_config._attn_implementation = "flash_attention_2"
  773. else:
  774. # not support sdpa
  775. self.config.vision_config._attn_implementation = "eager"
  776. model = SiglipVisionTransformer(self.config.vision_config)
  777. if self.config.drop_vision_last_layer:
  778. model.encoder.layers = model.encoder.layers[:-1]
  779. return model
  780. def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
  781. with set_default_torch_dtype(torch.float16):
  782. resampler = Resampler2_5(
  783. num_queries=self.config.query_num,
  784. embed_dim=embed_dim,
  785. num_heads=embed_dim // 128,
  786. kv_dim=vision_dim,
  787. )
  788. return resampler
  789. def get_vision_embedding(
  790. self,
  791. pixel_values: List[torch.Tensor],
  792. patch_attn_mask: Optional[torch.Tensor] = None,
  793. tgt_sizes: Optional[torch.Tensor] = None,
  794. ) -> torch.Tensor:
  795. vision_embedding = self.vpm(
  796. pixel_values,
  797. patch_attention_mask=patch_attn_mask,
  798. tgt_sizes=tgt_sizes,
  799. ).last_hidden_state
  800. return vision_embedding
  801. def get_vision_hidden_states(self,
  802. data: MiniCPMVImageInputs) -> torch.Tensor:
  803. pixel_values = data["pixel_values"]
  804. tgt_sizes = data["tgt_sizes"]
  805. device = self.vpm.embeddings.position_embedding.weight.device
  806. dtype = self.vpm.embeddings.position_embedding.weight.dtype
  807. all_pixel_values_lst = [
  808. i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
  809. ]
  810. max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
  811. assert isinstance(max_patches, int)
  812. all_pixel_values = torch.nn.utils.rnn.pad_sequence(
  813. all_pixel_values_lst, batch_first=True, padding_value=0.0)
  814. B, L, _ = all_pixel_values.shape
  815. all_pixel_values = all_pixel_values.permute(0, 2,
  816. 1).reshape(B, 3, -1, L)
  817. patch_attn_mask = torch.zeros((B, 1, max_patches),
  818. dtype=torch.bool,
  819. device=device)
  820. for i in range(B):
  821. patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
  822. vision_embedding = self.vpm(
  823. all_pixel_values.type(dtype),
  824. patch_attention_mask=patch_attn_mask,
  825. tgt_sizes=tgt_sizes,
  826. ).last_hidden_state
  827. return self.resampler(vision_embedding, tgt_sizes)
  828. def is_default_weight_loading(self, name: str) -> bool:
  829. return "resampler" in name or "vpm" in name
  830. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  831. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
  832. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
  833. @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
  834. class MiniCPMV(MiniCPMVBaseModel):
  835. """
  836. Different versions of MiniCPMV use different visual encoders and LLMs,
  837. which is not conducive to the current integration logic of LoRA and
  838. bitsandbytes in aphrodite. Therefore, it is necessary to separate them.
  839. """
  840. def __new__(
  841. cls,
  842. config: PretrainedConfig,
  843. multimodal_config: MultiModalConfig,
  844. cache_config: Optional[CacheConfig] = None,
  845. quant_config: Optional[QuantizationConfig] = None,
  846. ):
  847. if not hasattr(config, "version"):
  848. if config.hidden_size == 2304 and config.query_num == 64:
  849. version = (2, 0)
  850. else:
  851. version = (2, 5)
  852. else:
  853. version = str(config.version).split(".")
  854. version = tuple([int(x) for x in version])
  855. # Dispatch class based on version
  856. if version == (2, 0):
  857. instance_class = MiniCPMV2
  858. elif version == (2, 5):
  859. instance_class = MiniCPMV2_5
  860. else:
  861. instance_class = MiniCPMVQwen2
  862. return instance_class(config, multimodal_config, cache_config,
  863. quant_config)