internvl.py

# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.intern_vit import InternVisionModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
from .interfaces import SupportsVision
from .utils import (filter_weights, init_aphrodite_registered_model,
                    merge_vision_embeddings)

IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500


class InternVLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def calculate_num_blocks(orig_width: int,
                         orig_height: int,
                         min_num=1,
                         max_num=6,
                         image_size=448):
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    return blocks, target_width, target_height
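
# Worked example (illustrative, not from the original source): a 1000x500
# image has aspect ratio 2.0, so with the defaults (max_num=6, image_size=448)
# the closest grid is (2, 1), giving blocks=2 and a 896x448 resize target.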


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image,
                       min_num=1,
                       max_num=6,
                       image_size=448,
                       use_thumbnail=False):
    orig_width, orig_height = image.size

    blocks, target_width, target_height = calculate_num_blocks(
        orig_width, orig_height, min_num, max_num, image_size)

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image,
                                image_size=input_size,
                                use_thumbnail=True,
                                max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
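
# Note (added for clarity): the returned tensor has shape
# (blocks + 1, 3, input_size, input_size) -- one entry per 448x448 tile plus
# the global thumbnail appended by dynamic_preprocess(use_thumbnail=True),
# except when the image fits a single tile, in which case no thumbnail is
# appended.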


def get_internvl_num_patches(image_size: int, patch_size: int,
                             downsample_ratio: float):
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))
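
# Arithmetic check (illustrative, assuming the usual InternVL2 vision config
# of image_size=448, patch_size=14, downsample_ratio=0.5):
# (448 / 14) ** 2 = 1024 ViT patches, scaled by 0.5 ** 2 -> 256 tokens per tile.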


def get_max_internvl_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    downsample_ratio = hf_config.downsample_ratio
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)
    return num_patches * 7
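
# Explanatory note (inferred from the defaults above, not stated in the
# original): the factor of 7 matches the worst case of the default tiling --
# up to max_num=6 tiles plus the thumbnail, i.e. 7 * 256 = 1792 image tokens
# with the typical 256-tokens-per-tile configuration.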


def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        num_blocks, _, _ = calculate_num_blocks(width, height)
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    downsample_ratio = hf_config.downsample_ratio
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)

    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)
    image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
                                              1) * num_patches + IMG_END
    new_prompt = prompt.replace('<image>', image_prompt, 1)
    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=multi_modal_data)
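
# Example (illustrative): for a prompt containing "<image>" and an image that
# tiles into 2 blocks, the placeholder is replaced by
#     "<img>" + "<IMG_CONTEXT>" * (2 + 1) * 256 + "</img>"
# (assuming 256 tokens per tile), so the prompt token ids already reserve one
# position per vision embedding that forward() will merge in later.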


def input_mapper_for_internvl(ctx: InputContext, data: object):
    if isinstance(data, Image.Image):
        data = image_to_pixel_values(data)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    image_token_id = tokenizer.encode(IMG_CONTEXT,
                                      add_special_tokens=False,
                                      return_tensors="pt")[0]

    return MultiModalInputs({
        "pixel_values": data,
        "image_token_id": image_token_id
    })


def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
    image_feature_size = get_max_internvl_image_tokens(ctx)
    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        image_token_id=tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0],
        image_feature_size_override=image_feature_size,
    )

    mm_data = dummy_image_for_clip(
        vision_config,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsVision):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        vision_feature_layer = self.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1
        self.vision_model = InternVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size))

        self.img_context_token_id = None
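
    # Note (added for clarity): with downsample_ratio=0.5, pixel shuffling
    # concatenates 4 neighbouring ViT features, so mlp1 projects
    # vit_hidden_size * 4 -> llm_hidden_size; the concrete sizes depend on the
    # checkpoint's vision and text configs.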

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            pass
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x
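
    # Shape example (illustrative): for scale_factor=0.5, an input of shape
    # (N, 32, 32, C) comes out as (N, 16, 16, 4 * C) -- a space-to-depth step
    # that trades spatial resolution for channel width before the projector.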

    def extract_feature(self, pixel_values):
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
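
    # Shape trace (illustrative, assuming 448x448 tiles, a leading class token
    # and 1024 patches per tile): (B, 1025, C) from the ViT -> drop the class
    # token -> (B, 1024, C) -> reshape to (B, 32, 32, C) -> pixel_shuffle ->
    # (B, 16, 16, 4C) -> flatten to (B, 256, 4C) -> mlp1 -> (B, 256,
    # llm_hidden_size), i.e. 256 embeddings per tile, matching
    # get_internvl_num_patches.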

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")
        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # Each batch element is (num_patches, 3, h, w); compare only the
            # trailing dims so a variable number of patches is accepted.
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)

        if pixel_values is None:
            return None

        self.img_context_token_id = image_token_id[0]

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return InternVLImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            vit_embeds = self.extract_feature(image_input["data"])
            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                    vit_embeds,
                                                    self.img_context_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_model")
        self.vision_model.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "mlp1")
        mlp_params_dict = dict(self.mlp1.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)
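

# Usage sketch (illustrative only; assumes Aphrodite exposes a vLLM-style
# offline `LLM` entrypoint and that the prompt uses the "<image>" placeholder
# consumed by input_processor_for_internvl above -- verify against your
# installed version before relying on it):
#
#     from PIL import Image
#     from aphrodite import LLM
#
#     llm = LLM(model="OpenGVLab/InternVL2-4B", trust_remote_code=True)
#     image = Image.open("example.jpg")
#     outputs = llm.generate({
#         "prompt": "<image>\nDescribe this image.",
#         "multi_modal_data": {"image": image},
#     })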