internvl.py

# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.intern_vit import InternVisionModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
from .interfaces import SupportsMultiModal
from .utils import (filter_weights, init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class InternVLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """


class InternVLImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
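
# A minimal usage sketch (not part of the original module): the transform
# returned by `build_transform` maps a PIL image to a normalized CHW float
# tensor. Assuming an InternVL2-style input size of 448:
#
#     transform = build_transform(input_size=448)
#     tensor = transform(Image.open("example.jpg"))  # shape: (3, 448, 448)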


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
                         max_num: int,
                         image_size: int) -> Tuple[int, int, int]:
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    return blocks, target_width, target_height
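
# Worked example (illustrative only, assuming the common InternVL2 settings of
# min_num=1, max_num=12 and image_size=448): an 800x600 image has aspect
# ratio 4:3, and the closest tiling grid is (4, 3), so
#
#     calculate_num_blocks(800, 600, 1, 12, 448)
#
# would return (12, 1792, 1344): 12 tiles after resizing to 1792x1344.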


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                       image_size: int,
                       use_thumbnail: bool) -> List[Image.Image]:
    orig_width, orig_height = image.size

    blocks, target_width, target_height = calculate_num_blocks(
        orig_width, orig_height, min_num, max_num, image_size)

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image,
                                min_num=min_num,
                                max_num=max_num,
                                image_size=input_size,
                                use_thumbnail=use_thumbnail)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
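
# Continuing the worked example above (illustrative, assuming input_size=448,
# min_num=1, max_num=12 and use_thumbnail=True): an 800x600 image is split
# into 12 tiles plus one thumbnail, so `image_to_pixel_values` returns a
# tensor of shape (13, 3, 448, 448).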


def get_internvl_num_patches(image_size: int, patch_size: int,
                             downsample_ratio: float):
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))
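
# Worked arithmetic (illustrative, assuming the InternVL2 values
# image_size=448, patch_size=14 and downsample_ratio=0.5): the vision tower
# produces (448 // 14) ** 2 = 1024 patches per tile, and pixel shuffling by
# 0.5 keeps 1024 * 0.5 ** 2 = 256 visual tokens per tile.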


def get_max_internvl_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    use_thumbnail = hf_config.use_thumbnail
    max_dynamic_patch = hf_config.max_dynamic_patch
    if use_thumbnail:
        max_dynamic_patch += 1
    downsample_ratio = hf_config.downsample_ratio

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)
    return num_patches * max_dynamic_patch
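
# Illustrative upper bound (assuming max_dynamic_patch=12, use_thumbnail=True
# and 256 tokens per tile as in the example above): at most
# (12 + 1) * 256 = 3328 image tokens are budgeted per image.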


def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    downsample_ratio = hf_config.downsample_ratio
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        min_num = hf_config.min_dynamic_patch
        max_num = hf_config.max_dynamic_patch
        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
                                                max_num, image_size)
        # add thumbnail image if num_blocks > 1
        if hf_config.use_thumbnail and num_blocks > 1:
            num_blocks += 1
        image_feature_size = num_blocks * num_patches
    elif isinstance(image_data, torch.Tensor):
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)
    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
    new_prompt = prompt.replace('<image>', image_prompt, 1)
    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=multi_modal_data)
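
# Sketch of the prompt expansion (illustrative values): a prompt such as
# "<image>\nDescribe the picture." with a single-tile image
# (image_feature_size=256) is rewritten to
# "<img>" + "<IMG_CONTEXT>" * 256 + "</img>" + "\nDescribe the picture."
# before tokenization, so every visual token gets a placeholder position in
# the token id sequence.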


def input_mapper_for_internvl(ctx: InputContext, data: object):
    hf_config = ctx.get_hf_config(PretrainedConfig)

    use_thumbnail = hf_config.use_thumbnail
    min_num = hf_config.min_dynamic_patch
    max_num = hf_config.max_dynamic_patch
    image_size = hf_config.vision_config.image_size

    if isinstance(data, Image.Image):
        data = image_to_pixel_values(data,
                                     image_size,
                                     min_num,
                                     max_num,
                                     use_thumbnail=use_thumbnail)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    image_token_id = tokenizer.encode(IMG_CONTEXT,
                                      add_special_tokens=False,
                                      return_tensors="pt")[0]

    return MultiModalInputs({
        "pixel_values": data,
        "image_token_id": image_token_id
    })
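
# For a single PIL image, the returned MultiModalInputs carries (illustrative
# shapes) "pixel_values" of shape (num_tiles, 3, image_size, image_size) and
# "image_token_id", the tokenizer id of the <IMG_CONTEXT> placeholder, which
# the model later uses to locate where vision embeddings should be merged.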


def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
    image_feature_size = get_max_internvl_image_tokens(ctx)
    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        image_token_id=tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0],
        image_feature_size_override=image_feature_size,
    )

    image_size = vision_config.image_size
    min_num = hf_config.min_dynamic_patch
    max_num = hf_config.max_dynamic_patch
    max_image_width = max_num * image_size
    max_image_height = min_num * image_size

    mm_data = dummy_image_for_clip(
        vision_config,
        image_width_override=max_image_width,
        image_height_override=max_image_height,
    )

    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        vision_feature_layer = self.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1
        self.vision_model = InternVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size))

        self.img_context_token_id = None

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            pass
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
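
    # Shape walkthrough (illustrative, assuming image_size=448, patch_size=14,
    # downsample_ratio=0.5 and a ViT hidden size of 1024): the vision model
    # returns (num_tiles, 1025, 1024); dropping the CLS token and reshaping
    # gives (num_tiles, 32, 32, 1024); pixel_shuffle folds spatial positions
    # into channels, yielding (num_tiles, 16, 16, 4096); flattening and mlp1
    # then project this to (num_tiles, 256, llm_hidden_size).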

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")
        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        self.img_context_token_id = image_token_id[0]

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_embeds = self.extract_feature(image_input["data"])

        return image_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.img_context_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_model")
        self.vision_model.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "mlp1")
        mlp_params_dict = dict(self.mlp1.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)