internvl.py
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.intern_vit import InternVisionModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
from .interfaces import SupportsMultiModal
from .utils import (filter_weights, init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

class InternVLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """


class InternVLImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]

# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

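# A minimal usage sketch (hedged: assumes a 448px input size, as in typical
# InternVL2 configs, and a hypothetical PIL image `pil_image`):
#
#   transform = build_transform(448)
#   tensor = transform(pil_image)  # -> float tensor of shape (3, 448, 448),
#                                  #    normalized with IMAGENET_MEAN/STD above
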
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int,
                         image_size: int) -> Tuple[int, int, int]:
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    return blocks, target_width, target_height

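# Worked example (illustrative numbers, not taken from any particular model):
# an 800x600 image with min_num=1, max_num=6, image_size=448 has aspect ratio
# ~1.33; among the (i, j) grids with 1 <= i*j <= 6 the closest ratio is (3, 2)
# at 1.5, so calculate_num_blocks returns blocks=6, target_width=1344,
# target_height=896.
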
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: bool) -> List[Image.Image]:
    orig_width, orig_height = image.size

    blocks, target_width, target_height = calculate_num_blocks(
        orig_width, orig_height, min_num, max_num, image_size)

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

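# Continuing the example above: with a 3x2 grid (1344x896 after resize), tiles
# are cropped row-major; e.g. tile i=4 falls at column 4 % 3 = 1, row 4 // 3 = 1
# and gets the crop box (448, 448, 896, 896). With use_thumbnail=True and more
# than one tile, a 448x448 thumbnail of the whole image is appended, giving 7
# images in total for this case.
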
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    transform = build_transform(input_size=input_size)

    images = dynamic_preprocess(image,
                                min_num=min_num,
                                max_num=max_num,
                                image_size=input_size,
                                use_thumbnail=use_thumbnail)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

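# Rough usage sketch (the file name and config values below are hypothetical
# placeholders, not read from any checkpoint):
#
#   img = Image.open("example.jpg")
#   pixel_values = image_to_pixel_values(img, input_size=448, min_num=1,
#                                        max_num=12, use_thumbnail=True)
#   # pixel_values.shape -> (num_tiles, 3, 448, 448) with num_tiles <= 13
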
def get_internvl_num_patches(image_size: int, patch_size: int,
                             downsample_ratio: float):
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))

def get_max_internvl_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    use_thumbnail = hf_config.use_thumbnail
    max_dynamic_patch = hf_config.max_dynamic_patch
    if use_thumbnail:
        max_dynamic_patch += 1
    downsample_ratio = hf_config.downsample_ratio

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)
    return num_patches * max_dynamic_patch

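# Example arithmetic under typical InternVL2 settings (image_size=448,
# patch_size=14, downsample_ratio=0.5, max_dynamic_patch=12, use_thumbnail
# enabled -- assumed values): (448 / 14)^2 = 1024 ViT patches per tile are
# downsampled by 0.5^2 to 256 tokens, and with 12 + 1 tiles this function
# returns 256 * 13 = 3328 image tokens.
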
def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    downsample_ratio = hf_config.downsample_ratio
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        min_num = hf_config.min_dynamic_patch
        max_num = hf_config.max_dynamic_patch
        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
                                                max_num, image_size)
        # add thumbnail image if num_blocks > 1
        if hf_config.use_thumbnail and num_blocks > 1:
            num_blocks += 1
        image_feature_size = num_blocks * num_patches
    elif isinstance(image_data, torch.Tensor):
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)
    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
    new_prompt = prompt.replace('<image>', image_prompt, 1)
    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=multi_modal_data)

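# Illustrative expansion (the feature size is hypothetical): a prompt such as
# "<image>\nDescribe this picture." becomes
# IMG_START + IMG_CONTEXT * image_feature_size + IMG_END followed by
# "\nDescribe this picture.", and the expanded string is re-tokenized so the
# sequence reserves one placeholder token per image feature for the embedding
# merge performed in forward().
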
def input_mapper_for_internvl(ctx: InputContext, data: object):
    hf_config = ctx.get_hf_config()

    use_thumbnail = hf_config.use_thumbnail
    min_num = hf_config.min_dynamic_patch
    max_num = hf_config.max_dynamic_patch
    image_size = hf_config.vision_config.image_size

    if isinstance(data, Image.Image):
        data = image_to_pixel_values(data,
                                     image_size,
                                     min_num,
                                     max_num,
                                     use_thumbnail=use_thumbnail)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    image_token_id = tokenizer.encode(IMG_CONTEXT,
                                      add_special_tokens=False,
                                      return_tensors="pt")[0]

    return MultiModalInputs({
        "pixel_values": data,
        "image_token_id": image_token_id
    })

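# The mapper thus hands the model a "pixel_values" tensor shaped
# (num_tiles, 3, image_size, image_size) per image, plus the <IMG_CONTEXT>
# token id that _parse_and_validate_image_input() later stores as
# self.img_context_token_id.
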
def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    image_feature_size = get_max_internvl_image_tokens(ctx)
    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()
    vision_config = hf_config.vision_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0],
        image_feature_size_override=image_feature_size,
    )

    image_size = vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    max_num = hf_config.max_dynamic_patch
    max_image_width = max_num * image_size
    max_image_height = min_num * image_size

    mm_data = dummy_image_for_clip(
        vision_config,
        num_images,
        image_width_override=max_image_width,
        image_height_override=max_image_height,
    )

    return seq_data, mm_data

@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        vision_feature_layer = self.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1
        self.vision_model = InternVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size))

        self.img_context_token_id = None

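    # Projector sketch under assumed config values (vit_hidden_size=1024,
    # downsample_ratio=0.5, so int(1 / 0.5)**2 = 4): mlp1 is
    # LayerNorm(4096) -> Linear(4096, llm_hidden_size) -> GELU ->
    # Linear(llm_hidden_size, llm_hidden_size), matching the 4x channel growth
    # produced by pixel_shuffle() below.
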
    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            pass
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

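    # Shape walk-through with assumed values (N tiles, a 32x32 feature grid,
    # 1024 channels, scale_factor=0.5): (N, 32, 32, 1024)
    # -> view to (N, 32, 16, 2048) -> permute to (N, 16, 32, 2048)
    # -> view to (N, 16, 16, 4096); each 2x2 spatial block is folded into the
    # channel dimension, and for ps_version != 'v1' a final permute swaps the
    # two spatial axes back.
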
    def extract_feature(self, pixel_values):
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

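    # End-to-end shape sketch under the same assumptions (448px tiles, 14px
    # patches, downsample_ratio=0.5, ViT width 1024): the ViT output
    # (N, 1025, 1024) drops the class token to (N, 1024, 1024), is reshaped to
    # (N, 32, 32, 1024), pixel-shuffled to (N, 16, 16, 4096), flattened to
    # (N, 256, 4096), and projected by mlp1 to (N, 256, llm_hidden_size).
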
    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")
        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # each batch element has shape (num_patches, 3, h, w), so compare
            # everything after the leading patch dimension
            actual_dims = tuple(d.shape[1:])
            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        self.img_context_token_id = image_token_id[0]

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_embeds = self.extract_feature(image_input["data"])
        return image_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.img_context_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_model")
        self.vision_model.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "mlp1")
        mlp_params_dict = dict(self.mlp1.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)