1
0

internvl.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. # adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
  2. # --------------------------------------------------------
  3. # InternVL
  4. # Copyright (c) 2023 OpenGVLab
  5. # Licensed under The MIT License [see LICENSE for details]
  6. # --------------------------------------------------------
  7. import itertools
  8. import re
  9. from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
  10. TypedDict, Union)
  11. import torch
  12. import torch.nn as nn
  13. import torchvision.transforms as T
  14. from PIL import Image
  15. from transformers import PretrainedConfig
  16. from aphrodite.attention import AttentionMetadata
  17. from aphrodite.common.config import CacheConfig, MultiModalConfig
  18. from aphrodite.common.sequence import IntermediateTensors
  19. from aphrodite.common.utils import is_list_of
  20. from aphrodite.distributed import get_pp_group
  21. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  22. from aphrodite.modeling.layers.sampler import SamplerOutput
  23. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  24. from aphrodite.modeling.models.intern_vit import InternVisionModel
  25. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  26. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  27. from aphrodite.multimodal.base import MultiModalInputs
  28. from aphrodite.multimodal.utils import cached_get_tokenizer
  29. from aphrodite.quantization import QuantizationConfig
  30. from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
  31. get_clip_num_patches)
  32. from .interfaces import SupportsMultiModal
  33. from .utils import (filter_weights, flatten_bn,
  34. init_aphrodite_registered_model,
  35. merge_multimodal_embeddings)
  36. IMG_START = '<img>'
  37. IMG_END = '</img>'
  38. IMG_CONTEXT = '<IMG_CONTEXT>'
  39. IMAGENET_MEAN = (0.485, 0.456, 0.406)
  40. IMAGENET_STD = (0.229, 0.224, 0.225)
  41. class InternVLImagePixelInputs(TypedDict):
  42. type: Literal["pixel_values"]
  43. data: torch.Tensor
  44. """
  45. Shape:
  46. `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
  47. """
  48. class InternVLImageEmbeddingInputs(TypedDict):
  49. type: Literal["image_embeds"]
  50. data: torch.Tensor
  51. """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
  52. `hidden_size` must match the hidden size of language model backbone.
  53. """
  54. InternVLImageInputs = Union[InternVLImagePixelInputs,
  55. InternVLImageEmbeddingInputs]
  56. # copied from https://huggingface.co/OpenGVLab/InternVL2-1B
  57. def build_transform(input_size):
  58. MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
  59. transform = T.Compose([
  60. T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
  61. T.Resize((input_size, input_size),
  62. interpolation=T.InterpolationMode.BICUBIC),
  63. T.ToTensor(),
  64. T.Normalize(mean=MEAN, std=STD)
  65. ])
  66. return transform
  67. # copied from https://huggingface.co/OpenGVLab/InternVL2-1B
  68. def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
  69. image_size):
  70. best_ratio_diff = float('inf')
  71. best_ratio = (1, 1)
  72. area = width * height
  73. for ratio in target_ratios:
  74. target_aspect_ratio = ratio[0] / ratio[1]
  75. ratio_diff = abs(aspect_ratio - target_aspect_ratio)
  76. if ratio_diff < best_ratio_diff:
  77. best_ratio_diff = ratio_diff
  78. best_ratio = ratio
  79. elif ratio_diff == best_ratio_diff:
  80. if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
  81. best_ratio = ratio
  82. return best_ratio
  83. def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
  84. max_num: int, image_size: int,
  85. use_thumbnail: bool) -> Tuple[int, int, int]:
  86. aspect_ratio = orig_width / orig_height
  87. # calculate the existing image aspect ratio
  88. target_ratios = set((i, j) for n in range(min_num, max_num + 1)
  89. for i in range(1, n + 1) for j in range(1, n + 1)
  90. if i * j <= max_num and i * j >= min_num)
  91. target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
  92. # find the closest aspect ratio to the target
  93. target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
  94. target_ratios, orig_width,
  95. orig_height, image_size)
  96. # calculate the target width and height
  97. target_width = image_size * target_aspect_ratio[0]
  98. target_height = image_size * target_aspect_ratio[1]
  99. blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
  100. # add thumbnail image if num_blocks > 1
  101. if use_thumbnail and blocks > 1:
  102. blocks += 1
  103. return blocks, target_width, target_height
  104. # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
  105. def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
  106. image_size: int,
  107. use_thumbnail: bool) -> List[Image.Image]:
  108. orig_width, orig_height = image.size
  109. # calculate the number of blocks without thumbnail
  110. blocks, target_width, target_height = calculate_num_blocks(
  111. orig_width,
  112. orig_height,
  113. min_num,
  114. max_num,
  115. image_size,
  116. use_thumbnail=False)
  117. # resize the image
  118. resized_img = image.resize((target_width, target_height))
  119. processed_images = []
  120. for i in range(blocks):
  121. box = ((i % (target_width // image_size)) * image_size,
  122. (i // (target_width // image_size)) * image_size,
  123. ((i % (target_width // image_size)) + 1) * image_size,
  124. ((i // (target_width // image_size)) + 1) * image_size)
  125. # split the image
  126. split_img = resized_img.crop(box)
  127. processed_images.append(split_img)
  128. assert len(processed_images) == blocks
  129. if use_thumbnail and len(processed_images) != 1:
  130. thumbnail_img = image.resize((image_size, image_size))
  131. processed_images.append(thumbnail_img)
  132. return processed_images
  133. # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
  134. def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
  135. max_num: int, use_thumbnail: bool) -> torch.Tensor:
  136. transform = build_transform(input_size=input_size)
  137. images = dynamic_preprocess(image,
  138. min_num=min_num,
  139. max_num=max_num,
  140. image_size=input_size,
  141. use_thumbnail=use_thumbnail)
  142. pixel_values = [transform(image) for image in images]
  143. pixel_values = torch.stack(pixel_values)
  144. return pixel_values
  145. def get_internvl_num_patches(image_size: int, patch_size: int,
  146. downsample_ratio: float):
  147. return int(
  148. get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
  149. (downsample_ratio**2))
  150. def get_max_internvl_image_tokens(ctx: InputContext):
  151. hf_config = ctx.get_hf_config(PretrainedConfig)
  152. vision_config = hf_config.vision_config
  153. use_thumbnail = hf_config.use_thumbnail
  154. max_dynamic_patch = hf_config.max_dynamic_patch
  155. if use_thumbnail:
  156. max_dynamic_patch += 1
  157. downsample_ratio = hf_config.downsample_ratio
  158. image_size = vision_config.image_size
  159. patch_size = vision_config.patch_size
  160. num_patches = get_internvl_num_patches(image_size, patch_size,
  161. downsample_ratio)
  162. return num_patches * max_dynamic_patch
  163. def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
  164. multi_modal_data = llm_inputs.get("multi_modal_data")
  165. if multi_modal_data is None or "image" not in multi_modal_data:
  166. return llm_inputs
  167. model_config = ctx.model_config
  168. hf_config = ctx.get_hf_config()
  169. vision_config = hf_config.vision_config
  170. image_size = vision_config.image_size
  171. patch_size = vision_config.patch_size
  172. downsample_ratio = hf_config.downsample_ratio
  173. num_patches = get_internvl_num_patches(image_size, patch_size,
  174. downsample_ratio)
  175. image_data = multi_modal_data["image"]
  176. min_num = hf_config.min_dynamic_patch
  177. max_num = hf_config.max_dynamic_patch
  178. use_thumbnail = hf_config.use_thumbnail
  179. if isinstance(image_data, Image.Image):
  180. width, height = image_data.size
  181. num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
  182. max_num, image_size,
  183. use_thumbnail)
  184. image_feature_size = [num_blocks * num_patches]
  185. elif is_list_of(image_data, Image.Image):
  186. image_feature_size = []
  187. for image in image_data:
  188. width, height = image.size
  189. num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
  190. max_num, image_size,
  191. use_thumbnail)
  192. image_feature_size.append(num_blocks * num_patches)
  193. elif isinstance(image_data, torch.Tensor):
  194. num_images, image_feature_size, hidden_size = image_data.shape
  195. else:
  196. raise TypeError(f"Invalid image type: {type(image_data)}")
  197. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  198. trust_remote_code=True)
  199. prompt = llm_inputs.get("prompt")
  200. prompt_token_ids = llm_inputs["prompt_token_ids"]
  201. if prompt is None:
  202. prompt = tokenizer.decode(prompt_token_ids)
  203. new_prompt = prompt
  204. image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
  205. for idx, feature_size in enumerate(image_feature_size, start=1):
  206. image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
  207. if not image_idx:
  208. image_prompt = f"Image-{idx}: {image_prompt}"
  209. new_prompt = new_prompt.replace('<image>', image_prompt, 1)
  210. new_prompt_token_ids = tokenizer.encode(new_prompt)
  211. return LLMInputs(prompt=prompt,
  212. prompt_token_ids=new_prompt_token_ids,
  213. multi_modal_data=multi_modal_data)
  214. def input_mapper_for_internvl(ctx: InputContext, data: object):
  215. hf_config = ctx.get_hf_config()
  216. use_thumbnail = hf_config.use_thumbnail
  217. min_num = hf_config.min_dynamic_patch
  218. max_num = hf_config.max_dynamic_patch
  219. image_size = hf_config.vision_config.image_size
  220. if isinstance(data, Image.Image):
  221. data = image_to_pixel_values(data,
  222. image_size,
  223. min_num,
  224. max_num,
  225. use_thumbnail=use_thumbnail)
  226. # Add an N dimension for number of images per prompt (currently 1).
  227. data = data.unsqueeze(0)
  228. elif is_list_of(data, Image.Image):
  229. # we can't stack here because the images may have different num_patches
  230. data = [
  231. image_to_pixel_values(img,
  232. image_size,
  233. min_num,
  234. max_num,
  235. use_thumbnail=use_thumbnail) for img in data
  236. ]
  237. model_config = ctx.model_config
  238. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  239. trust_remote_code=True)
  240. image_token_id = tokenizer.encode(IMG_CONTEXT,
  241. add_special_tokens=False,
  242. return_tensors="pt")[0]
  243. return MultiModalInputs({
  244. "pixel_values": data,
  245. "image_token_id": image_token_id
  246. })
  247. def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
  248. mm_counts: Mapping[str, int]):
  249. num_images = mm_counts["image"]
  250. image_feature_size = get_max_internvl_image_tokens(ctx)
  251. model_config = ctx.model_config
  252. hf_config = ctx.get_hf_config()
  253. vision_config = hf_config.vision_config
  254. tokenizer = cached_get_tokenizer(model_config.tokenizer,
  255. trust_remote_code=True)
  256. seq_data = dummy_seq_data_for_clip(
  257. vision_config,
  258. seq_len,
  259. num_images,
  260. image_token_id=tokenizer.encode(IMG_CONTEXT,
  261. add_special_tokens=False)[0],
  262. image_feature_size_override=image_feature_size,
  263. )
  264. image_size = vision_config.image_size
  265. min_num = hf_config.min_dynamic_patch
  266. max_num = hf_config.max_dynamic_patch
  267. max_image_width = max_num * image_size
  268. max_image_height = min_num * image_size
  269. mm_data = dummy_image_for_clip(
  270. vision_config,
  271. num_images,
  272. image_width_override=max_image_width,
  273. image_height_override=max_image_height,
  274. )
  275. return seq_data, mm_data
  276. @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
  277. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
  278. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
  279. @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
  280. class InternVLChatModel(nn.Module, SupportsMultiModal):
  281. def __init__(self,
  282. config: PretrainedConfig,
  283. multimodal_config: MultiModalConfig,
  284. cache_config: Optional[CacheConfig] = None,
  285. quant_config: Optional[QuantizationConfig] = None) -> None:
  286. super().__init__()
  287. self.config = config
  288. self.multimodal_config = multimodal_config
  289. image_size = config.force_image_size or config.vision_config.image_size
  290. patch_size = config.vision_config.patch_size
  291. self.patch_size = patch_size
  292. self.select_layer = config.select_layer
  293. self.num_image_token = int(
  294. (image_size // patch_size)**2 * (config.downsample_ratio**2))
  295. self.downsample_ratio = config.downsample_ratio
  296. self.ps_version = config.ps_version
  297. vision_feature_layer = self.select_layer
  298. if vision_feature_layer < 0:
  299. num_hidden_layers = config.vision_config.num_hidden_layers \
  300. + vision_feature_layer + 1
  301. else:
  302. num_hidden_layers = vision_feature_layer + 1
  303. self.vision_model = InternVisionModel(
  304. config.vision_config, num_hidden_layers_override=num_hidden_layers)
  305. self.language_model = init_aphrodite_registered_model(
  306. config.text_config, cache_config, quant_config)
  307. vit_hidden_size = config.vision_config.hidden_size
  308. llm_hidden_size = config.text_config.hidden_size
  309. self.mlp1 = nn.Sequential(
  310. nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
  311. nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
  312. llm_hidden_size), nn.GELU(),
  313. nn.Linear(llm_hidden_size, llm_hidden_size))
  314. self.img_context_token_id = None
  315. self.make_empty_intermediate_tensors = (
  316. self.language_model.make_empty_intermediate_tensors)
  317. def pixel_shuffle(self, x, scale_factor=0.5):
  318. n, w, h, c = x.size()
  319. # N, W, H, C --> N, W, H * scale, C // scale
  320. x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
  321. # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
  322. x = x.permute(0, 2, 1, 3).contiguous()
  323. x = x.view(n, int(h * scale_factor), int(w * scale_factor),
  324. int(c / (scale_factor * scale_factor)))
  325. if self.ps_version == 'v1':
  326. pass
  327. else:
  328. x = x.permute(0, 2, 1, 3).contiguous()
  329. return x
  330. def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
  331. vit_embeds = self.vision_model(pixel_values=pixel_values)
  332. vit_embeds = vit_embeds[:, 1:, :]
  333. h = w = int(vit_embeds.shape[1]**0.5)
  334. vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
  335. vit_embeds = self.pixel_shuffle(vit_embeds,
  336. scale_factor=self.downsample_ratio)
  337. vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
  338. vit_embeds.shape[-1])
  339. vit_embeds = self.mlp1(vit_embeds)
  340. return vit_embeds
  341. def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
  342. h = w = self.config.vision_config.image_size
  343. expected_dims = (3, h, w)
  344. def _validate_shape(d: torch.Tensor):
  345. actual_dims = tuple(d.shape)
  346. if actual_dims != expected_dims:
  347. expected_expr = str(expected_dims)
  348. raise ValueError(
  349. "The expected shape of pixel values per image per batch "
  350. f" per patch is {expected_expr}. "
  351. f"You supplied {tuple(d.shape)}.")
  352. for d in data:
  353. _validate_shape(d)
  354. return data
  355. def _parse_and_validate_image_input(
  356. self, **kwargs: object) -> Optional[InternVLImageInputs]:
  357. pixel_values = kwargs.pop("pixel_values", None)
  358. image_token_id = kwargs.pop("image_token_id", None)
  359. image_embeds = kwargs.pop("image_embeds", None)
  360. if pixel_values is None and image_embeds is None:
  361. return None
  362. if image_embeds is not None:
  363. if not isinstance(image_embeds, torch.Tensor):
  364. raise ValueError("Incorrect type of image embeddings. "
  365. f"Got type: {type(image_embeds)}")
  366. return InternVLImageEmbeddingInputs(
  367. type="image_embeds",
  368. data=flatten_bn(image_embeds),
  369. )
  370. self.img_context_token_id = image_token_id[0]
  371. if pixel_values is not None:
  372. if not isinstance(pixel_values, (torch.Tensor, list)):
  373. raise ValueError("Incorrect type of pixel values. "
  374. f"Got type: {type(pixel_values)}")
  375. # We need to flatten (B, N, P) to (B*N*P),
  376. # so we call flatten_bn twice.
  377. return InternVLImagePixelInputs(
  378. type="pixel_values",
  379. data=self._validate_pixel_values(
  380. flatten_bn(flatten_bn(pixel_values), concat=True)),
  381. )
  382. raise AssertionError("This line should be unreachable.")
  383. def _process_image_input(
  384. self,
  385. image_input: InternVLImageInputs,
  386. ) -> torch.Tensor:
  387. if image_input["type"] == "image_embeds":
  388. return image_input["data"]
  389. assert self.vision_model is not None
  390. image_embeds = self.extract_feature(image_input["data"])
  391. return image_embeds
  392. def forward(
  393. self,
  394. input_ids: torch.Tensor,
  395. positions: torch.Tensor,
  396. kv_caches: List[torch.Tensor],
  397. attn_metadata: AttentionMetadata,
  398. intermediate_tensors: Optional[IntermediateTensors] = None,
  399. **kwargs: object,
  400. ) -> SamplerOutput:
  401. image_input = self._parse_and_validate_image_input(**kwargs)
  402. if image_input is not None and get_pp_group().is_first_rank:
  403. inputs_embeds = self.language_model.model.get_input_embeddings(
  404. input_ids)
  405. vision_embeddings = self._process_image_input(image_input)
  406. inputs_embeds = merge_multimodal_embeddings(
  407. input_ids, inputs_embeds, vision_embeddings,
  408. self.img_context_token_id)
  409. input_ids = None
  410. else:
  411. inputs_embeds = None
  412. hidden_states = self.language_model.model(input_ids,
  413. positions,
  414. kv_caches,
  415. attn_metadata,
  416. intermediate_tensors,
  417. inputs_embeds=inputs_embeds)
  418. return hidden_states
  419. def compute_logits(
  420. self,
  421. hidden_states: torch.Tensor,
  422. sampling_metadata: SamplingMetadata,
  423. ) -> Optional[torch.Tensor]:
  424. return self.language_model.compute_logits(hidden_states,
  425. sampling_metadata)
  426. def sample(
  427. self,
  428. logits: torch.Tensor,
  429. sampling_metadata: SamplingMetadata,
  430. ) -> Optional[SamplerOutput]:
  431. return self.language_model.sample(logits, sampling_metadata)
  432. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  433. # prepare weight iterators for components
  434. vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
  435. # load vision encoder
  436. vit_weights = filter_weights(vit_weights, "vision_model")
  437. self.vision_model.load_weights(vit_weights)
  438. # load mlp projector
  439. mlp_weights = filter_weights(mlp_weights, "mlp1")
  440. mlp_params_dict = dict(self.mlp1.named_parameters())
  441. for name, loaded_weight in mlp_weights:
  442. param = mlp_params_dict[name]
  443. weight_loader = getattr(param, "weight_loader",
  444. default_weight_loader)
  445. weight_loader(param, loaded_weight)
  446. # load llm backbone
  447. llm_weights = filter_weights(llm_weights, "language_model")
  448. self.language_model.load_weights(llm_weights)