# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import re
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.distributed import get_pp_group
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.intern_vit import InternVisionModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.utils import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
from .interfaces import SupportsMultiModal
from .utils import (flatten_bn, group_weights_with_prefix,
                    init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
    """


class InternVLImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
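            # Tie-break: candidates are visited in order of increasing tile
            # count, so on an exact tie prefer the larger tiling only when
            # the original image covers more than half of that tiling's area.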
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int, image_size: int,
                         use_thumbnail: bool) -> Tuple[int, int, int]:
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # add thumbnail image if num_blocks > 1
    if use_thumbnail and blocks > 1:
        blocks += 1
    return blocks, target_width, target_height
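

# Worked example (values depend on the checkpoint config): an 800x600 image
# with image_size=448, min_num=1, max_num=12 has aspect ratio 4:3, so the
# closest target ratio is (4, 3) -> 12 tiles of 448x448 on a 1792x1344
# canvas, plus one extra 448x448 thumbnail tile when use_thumbnail=True.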


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: bool) -> List[Image.Image]:
    orig_width, orig_height = image.size

    # calculate the number of blocks without thumbnail
    blocks, target_width, target_height = calculate_num_blocks(
        orig_width,
        orig_height,
        min_num,
        max_num,
        image_size,
        use_thumbnail=False)
    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image,
                                min_num=min_num,
                                max_num=max_num,
                                image_size=input_size,
                                use_thumbnail=use_thumbnail)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
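

# The returned tensor has shape (num_blocks [+ 1 thumbnail], 3, input_size,
# input_size), i.e. one normalized RGB tile per crop produced above.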


def get_internvl_num_patches(image_size: int, patch_size: int,
                             downsample_ratio: float):
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))


def get_max_internvl_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    use_thumbnail = hf_config.use_thumbnail
    max_dynamic_patch = hf_config.max_dynamic_patch
    if use_thumbnail:
        max_dynamic_patch += 1
    downsample_ratio = hf_config.downsample_ratio

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)
    return num_patches * max_dynamic_patch
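

# E.g. for InternVL2-style configs (image_size=448, patch_size=14,
# downsample_ratio=0.5), get_internvl_num_patches gives
# int((448 // 14) ** 2 * 0.5 ** 2) = 256 tokens per tile, so the maximum per
# image is (max_dynamic_patch [+ 1 thumbnail]) * 256 tokens.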


def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()
    vision_config = hf_config.vision_config

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    downsample_ratio = hf_config.downsample_ratio
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)

    image_data = multi_modal_data["image"]
    min_num = hf_config.min_dynamic_patch
    max_num = hf_config.max_dynamic_patch
    use_thumbnail = hf_config.use_thumbnail
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
                                                max_num, image_size,
                                                use_thumbnail)
        image_feature_size = [num_blocks * num_patches]
    elif is_list_of(image_data, Image.Image):
        image_feature_size = []
        for image in image_data:
            width, height = image.size
            num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
                                                    max_num, image_size,
                                                    use_thumbnail)
            image_feature_size.append(num_blocks * num_patches)
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)

    new_prompt = prompt
    image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
    for idx, feature_size in enumerate(image_feature_size, start=1):
        image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
        if not image_idx:
            image_prompt = f"Image-{idx}: {image_prompt}"
        new_prompt = new_prompt.replace('<image>', image_prompt, 1)
    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=multi_modal_data)
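

# Example: a prompt containing "<image>\n" for an image that tiles into
# `num_blocks` crops is rewritten so that the placeholder becomes
#   "Image-1: <img>" + "<IMG_CONTEXT>" * (num_blocks * num_patches) + "</img>"
# (the "Image-k: " prefix is only added when the prompt does not already
# number its images), so the prompt length matches the number of vision
# embeddings merged in during forward().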


def input_mapper_for_internvl(ctx: InputContext, data: object):
    hf_config = ctx.get_hf_config()
    use_thumbnail = hf_config.use_thumbnail
    min_num = hf_config.min_dynamic_patch
    max_num = hf_config.max_dynamic_patch
    image_size = hf_config.vision_config.image_size

    if isinstance(data, Image.Image):
        data = image_to_pixel_values(data,
                                     image_size,
                                     min_num,
                                     max_num,
                                     use_thumbnail=use_thumbnail)
        # Add an N dimension for number of images per prompt (currently 1).
        data = data.unsqueeze(0)
    elif is_list_of(data, Image.Image):
        # we can't stack here because the images may have different num_patches
        data = [
            image_to_pixel_values(img,
                                  image_size,
                                  min_num,
                                  max_num,
                                  use_thumbnail=use_thumbnail) for img in data
        ]
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    image_token_id = tokenizer.encode(IMG_CONTEXT,
                                      add_special_tokens=False,
                                      return_tensors="pt")[0]

    return MultiModalInputs({
        "pixel_values": data,
        "image_token_id": image_token_id
    })


def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    image_feature_size = get_max_internvl_image_tokens(ctx)
    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()
    vision_config = hf_config.vision_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0],
        image_feature_size_override=image_feature_size,
    )

    image_size = vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    max_num = hf_config.max_dynamic_patch
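    # Use an extreme (max_num : min_num) aspect ratio so that dynamic tiling
    # produces the maximum number of crops, giving worst-case memory usage
    # during profiling.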
    max_image_width = max_num * image_size
    max_image_height = min_num * image_size

    mm_data = dummy_image_for_clip(
        vision_config,
        num_images,
        image_width_override=max_image_width,
        image_height_override=max_image_height,
    )

    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        vision_feature_layer = self.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1
        self.vision_model = InternVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size))

        self.img_context_token_id = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            pass
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x
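
    # With scale_factor=0.5, pixel_shuffle folds each 2x2 group of spatial
    # tokens into the channel dimension, mapping (N, W, H, C) to
    # (N, W/2, H/2, 4*C) (the two spatial dims stay swapped when
    # ps_version == 'v1'), quartering the token count.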

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
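
    # Shape flow (assuming a 448-pixel, patch-14 vision config):
    # (num_tiles, 3, 448, 448) -> ViT -> (num_tiles, 1025, vit_hidden)
    # -> drop the CLS token and reshape -> (num_tiles, 32, 32, vit_hidden)
    # -> pixel_shuffle -> (num_tiles, 16, 16, 4*vit_hidden)
    # -> mlp1 -> (num_tiles, 256, llm_hidden).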

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"per patch is {expected_expr}. "
                    f"You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        self.img_context_token_id = image_token_id[0]

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # We need to flatten (B, N, P) to (B*N*P),
            # so we call flatten_bn twice.
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(flatten_bn(pixel_values), concat=True)),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_embeds = self.extract_feature(image_input["data"])

        return image_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None and get_pp_group().is_first_rank:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.img_context_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load vision encoder
        self.vision_model.load_weights(weights_group["vision_model"])

        # load mlp projector
        mlp_params_dict = dict(self.mlp1.named_parameters())
        for name, loaded_weight in weights_group["mlp1"]:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])
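

# Hypothetical usage sketch (assumes Aphrodite exposes a vLLM-style `LLM`
# entrypoint with a multimodal generate() API; the names below are
# illustrative and not defined in this file):
#
#   from PIL import Image
#   from aphrodite import LLM, SamplingParams
#
#   llm = LLM(model="OpenGVLab/InternVL2-4B", trust_remote_code=True)
#   image = Image.open("example.jpg")
#   outputs = llm.generate(
#       {"prompt": "<image>\nDescribe this image.",
#        "multi_modal_data": {"image": image}},
#       SamplingParams(max_tokens=64),
#   )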