# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models import ModelRegistry
from aphrodite.modeling.models.intern_vit import InternVisionModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_num_patches)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500


class InternVLImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size),
                 interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def calculate_num_blocks(orig_width: int,
                         orig_height: int,
                         min_num=1,
                         max_num=6,
                         image_size=448):
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1)
                        for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    return blocks, target_width, target_height
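
# Example (for illustration only): a 1000x500 image has aspect ratio 2.0,
# whose closest allowed tiling is (2, 1), so calculate_num_blocks(1000, 500)
# returns (2, 896, 448) with the default image_size of 448.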


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image,
                       min_num=1,
                       max_num=6,
                       image_size=448,
                       use_thumbnail=False):
    orig_width, orig_height = image.size

    blocks, target_width, target_height = calculate_num_blocks(
        orig_width, orig_height, min_num, max_num, image_size)

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
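
# Continuing the example above: dynamic_preprocess on a 1000x500 image with
# use_thumbnail=True yields 2 tiles of 448x448 plus a 448x448 thumbnail of the
# whole image, i.e. 3 PIL images in total.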


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image,
                                image_size=input_size,
                                use_thumbnail=True,
                                max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
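
# The result is a float tensor of shape (num_tiles, 3, input_size, input_size);
# for the 1000x500 example this would be (3, 3, 448, 448).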


def get_internvl_num_patches(image_size: int, patch_size: int,
                             downsample_ratio: float):
    return int(
        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
        (downsample_ratio**2))
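
# With the InternVL2 defaults (image_size=448, patch_size=14,
# downsample_ratio=0.5) this works out to (448 / 14)**2 * 0.25 = 1024 * 0.25
# = 256 feature tokens per 448x448 tile.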


def get_max_internvl_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    downsample_ratio = hf_config.downsample_ratio
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)
    return num_patches * 7
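
# The factor of 7 is the worst case of the dynamic tiling above: at most
# max_num=6 tiles plus the thumbnail, so with the defaults the upper bound is
# 7 * 256 = 1792 image tokens per image.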


def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        num_blocks, _, _ = calculate_num_blocks(width, height)
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    image_size = vision_config.image_size
    patch_size = vision_config.patch_size
    downsample_ratio = hf_config.downsample_ratio
    num_patches = get_internvl_num_patches(image_size, patch_size,
                                           downsample_ratio)

    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)
    image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
                                              1) * num_patches + IMG_END
    new_prompt = prompt.replace('<image>', image_prompt, 1)
    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=multi_modal_data)
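
# Illustrative example: for the 2-tile image above, the "<image>" placeholder
# in a prompt such as "<image>\nDescribe the picture." is replaced with
# "<img>" + "<IMG_CONTEXT>" * (2 + 1) * 256 + "</img>", i.e. one context token
# per image feature (tiles plus thumbnail) wrapped in the start/end markers.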


def input_mapper_for_internvl(ctx: InputContext, data: object):
    if isinstance(data, Image.Image):
        data = image_to_pixel_values(data)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)
    image_token_id = tokenizer.encode(IMG_CONTEXT,
                                      add_special_tokens=False,
                                      return_tensors="pt")[0]

    return MultiModalInputs({
        "pixel_values": data,
        "image_token_id": image_token_id
    })


def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
    image_feature_size = get_max_internvl_image_tokens(ctx)
    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        image_token_id=tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0],
        image_feature_size_override=image_feature_size,
    )

    mm_data = dummy_image_for_clip(
        vision_config,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data
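
# A 3000x500 dummy image has aspect ratio 6.0, so the dynamic tiling picks the
# (6, 1) grid: 6 tiles plus the thumbnail, matching the 7x worst case assumed
# by get_max_internvl_image_tokens.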


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsVision):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.num_image_token = int(
            (image_size // patch_size)**2 * (config.downsample_ratio**2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        vision_feature_layer = self.select_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1
        self.vision_model = InternVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)

        llm_class = ModelRegistry.load_model_cls(
            config.text_config.architectures[0])
        self.language_model = llm_class(config.text_config, cache_config,
                                        quant_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size))

        self.img_context_token_id = None
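
    # With downsample_ratio=0.5, int(1 / 0.5)**2 == 4, so mlp1 projects the
    # 4x channel-concatenated ViT features (produced by pixel_shuffle below)
    # down to the language model's hidden size.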

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            pass
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x
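
    # Shape example: with scale_factor=0.5, an input of (N, 32, 32, C) becomes
    # (N, 16, 16, 4 * C): every 2x2 block of spatial positions is folded into
    # the channel dimension, quartering the number of visual tokens.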

    def extract_feature(self, pixel_values):
        vit_embeds = self.vision_model(pixel_values=pixel_values)
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1]**0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds,
                                        scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
                                        vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds
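
    # For a 448x448 tile the ViT emits one CLS token plus 32 * 32 = 1024 patch
    # tokens; dropping CLS, pixel-shuffling with a 0.5 ratio and projecting
    # through mlp1 leaves 256 embeddings per tile in the LLM's hidden size.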

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")
        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # Each batch element has shape (num_patches, 3, h, w); compare
            # only the per-patch dims against the expected (3, h, w).
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)

        if pixel_values is None:
            return None

        self.img_context_token_id = image_token_id[0]

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return InternVLImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            vit_embeds = self.extract_feature(image_input["data"])
            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                    vit_embeds,
                                                    self.img_context_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            (".gate_up_proj", ".w1", 0),
            (".gate_up_proj", ".w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.text_config.tie_word_embeddings \
                and "lm_head.weight" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # We only do sharding for language model
                # and not vision model for now.
                if "vision_embed_tokens" in name and self.vision_embed_tokens:
                    continue
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                if "wqkv" in name:
                    config = self.config.text_config
                    kv_groups = (config.num_attention_heads //
                                 config.num_key_value_heads)
                    head_dim = config.hidden_size // config.num_attention_heads
                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
                                                       head_dim,
                                                       loaded_weight.shape[-1])
                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
                                             dim=1)
                    wq = wq.reshape(-1, wq.shape[-1])
                    wk = wk.reshape(-1, wk.shape[-1])
                    wv = wv.reshape(-1, wv.shape[-1])
                    weight_loader = param.weight_loader
                    weight_loader(param, wq, 'q')
                    weight_loader(param, wk, 'k')
                    weight_loader(param, wv, 'v')
                    continue
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
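

# Rough usage sketch (assumption: aphrodite exposes a vLLM-style offline API;
# the entrypoint names below are illustrative and not taken from this file):
#
#   from aphrodite import LLM, SamplingParams
#   from PIL import Image
#
#   llm = LLM(model="OpenGVLab/InternVL2-4B", trust_remote_code=True)
#   image = Image.open("example.jpg")
#   outputs = llm.generate({
#       "prompt": "<image>\nDescribe the picture.",
#       "multi_modal_data": {"image": image},
#   }, SamplingParams(max_tokens=64))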