phi3v.py

# coding=utf-8
# Copyright 2024 The PygmalionAI team.
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import re
from functools import lru_cache
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                    Tuple, TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, ModelConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        repeat_and_pad_token)
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal
from .utils import flatten_bn, merge_multimodal_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

# The following values cannot be found in the HF config.
_IMAGE_TOKEN_ID = 32044

# These dimensions result in the maximum possible feature size (h:w = 16:1).
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)


def _init_img_processor(hf_config: PretrainedConfig):
    clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    layer_idx = hf_config.img_processor.get('layer_idx', -2)

    # Initialize the CLIP only up to the required feature layer
    if layer_idx < 0:
        num_hidden_layers = clip_config.num_hidden_layers + \
            layer_idx + 1
    else:
        num_hidden_layers = layer_idx + 1

    img_processor = CLIPVisionModel(
        clip_config, num_hidden_layers_override=num_hidden_layers)

    return img_processor
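
# NOTE: With the default `layer_idx = -2` and the 24-layer CLIP config above,
# `num_hidden_layers` works out to 24 + (-2) + 1 = 23, i.e. the vision tower
# is built with only 23 layers and its last hidden state is the requested
# feature layer.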


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        TYPE_FEATURE = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if TYPE_FEATURE == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError
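
# NOTE: With the 336-px / 14-px-patch CLIP configured above, each crop yields
# 1 + 24 * 24 = 577 tokens of width 1024; the default "patch" mode drops the
# leading CLS token, so `get_img_features` returns `(N, 576, 1024)`.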


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        # n_embd or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        self.img_processor = _init_img_processor(config)

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serve as line separators
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # use_hd_transform and with_learnable_separator should have the same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
        """
        Process images and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        output: (num_images, num_img_tokens, hidden_size)
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        image_features_proj = self.hd_feature_transform(
            img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """
        image_features: (num_images, num_crops+1, 24*24, 1024)
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:,
                                               0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image
        # sizes (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                sub_image_features_hd_newline.squeeze(
                    0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])

            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd
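
    # NOTE: For example, the global feature (h_crop = w_crop = 1) maps
    # (1, 576, 1024) -> (1, 12, 12, 4096), while a 2x1 crop grid over two crop
    # features (2, 576, 1024) maps to (1, 24, 12, 4096).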

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height
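
# NOTE: Only the height is padded up to the next multiple of `padding_unit`;
# e.g. width=1680, height=840 yields (1680, 1008).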


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = _calc_padded_size(width=new_width,
                                                    height=new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height
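
# NOTE: For example, a 1344x672 (w x h) input with hd_num=16 gives scale=5,
# so it is resized to 1680x840 and then padded to 1680x1008.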


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
    hf_config: Dict[str, Any],
    *,
    input_height: int,
    input_width: int,
) -> int:
    num_crops = hf_config.get("num_crops", 16)
    new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                    height=input_height,
                                                    hd_num=num_crops)

    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
        + (new_height // 336 + 1) * 12
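
# NOTE: Continuing the example above, a 1344x672 image is transformed to
# 1680x1008, i.e. a 5x3 crop grid, so its feature size is
# (3 * 5 + 1) * 144 + 1 + (3 + 1) * 12 = 2353 tokens.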


def get_max_phi3v_image_tokens(ctx: InputContext):
    return get_phi3v_image_feature_size(
        ctx.get_hf_image_processor_config(),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )
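
# NOTE: With the default num_crops=16, the 50x8000 extreme above evaluates to
# (16 * 1 + 1) * 144 + 1 + (16 + 1) * 12 = 2653 tokens by the formula above.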


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    image_feature_size = get_max_phi3v_image_tokens(ctx)

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        num_images,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )
    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        num_images,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
                                     idx: int) -> List[int]:
    assert idx > 0

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    # We need to get the token for "<", not "▁<"
    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
        f"a<|image_{idx}|>", add_special_tokens=False)
    assert a_token_id == a_token_id_

    return image_placeholder_token_ids


def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_image_processor_config()

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        w, h = image_data.size
        image_feature_size = [
            get_phi3v_image_feature_size(hf_config,
                                         input_width=w,
                                         input_height=h)
        ]
        image_data = [image_data]
    elif is_list_of(image_data, Image.Image):
        image_feature_size = []
        for image in image_data:
            w, h = image.size
            image_feature_size.append(
                get_phi3v_image_feature_size(hf_config,
                                             input_width=w,
                                             input_height=h))
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        # For async server requests, we assume the prompt and its token_ids
        # are always in the correct format, and that
        # num_image_tags == len(image_data) always holds.
        image_idx = range(1, len(image_data) + 1)
        new_prompt = None
    else:
        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
        if prompt.count("<|image|>") > 0:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           "repeating <|image|> tokens.")
        elif (num_image_tags := len(image_idx)) > 1:
            assert num_image_tags == len(image_data), (
                "The number of image placeholders does not match the number "
                "of images.")

        new_prompt = prompt

    prompt_token_ids = llm_inputs["prompt_token_ids"].copy()

    # Mask each placeholder token sequence with the image token id.
    for idx in image_idx:
        image_token_ids = _get_image_placeholder_token_ids(model_config,
                                                           idx=idx)
        for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
            if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
                prompt_token_ids[i:i + len(image_token_ids)] = [
                    _IMAGE_TOKEN_ID
                ] * len(image_token_ids)
                break

    # Merge consecutive runs of the image token id into a single token.
    merged_token_ids: List[int] = []
    for is_placeholder, token_ids in itertools.groupby(
            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
        if is_placeholder:
            merged_token_ids.append(_IMAGE_TOKEN_ID)
        else:
            merged_token_ids.extend(list(token_ids))

    # TODO: Move this to utils or integrate with clip.
    new_token_ids: List[int] = []
    placeholder_idx = 0
    while merged_token_ids:
        token_id = merged_token_ids.pop(0)
        if token_id == _IMAGE_TOKEN_ID:
            new_token_ids.extend(
                repeat_and_pad_token(
                    _IMAGE_TOKEN_ID,
                    repeat_count=image_feature_size[placeholder_idx],
                ))
            placeholder_idx += 1
        else:
            new_token_ids.append(token_id)

    # NOTE: Create a defensive copy of the original inputs
    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                           prompt=new_prompt,
                           multi_modal_data=multi_modal_data)
    return llm_inputs
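
# NOTE: For illustration, a prompt containing "<|image_1|>" has that
# placeholder's token ids masked, merged, and then expanded into
# image_feature_size[0] repetitions of _IMAGE_TOKEN_ID (32044), e.g. 2353
# repetitions for a 1344x672 image.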


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        self.model = LlamaModel(config, cache_config, quant_config)

        # TODO: Optionally initialize this to support embedding inputs.
        self.vision_embed_tokens = Phi3HDImageEmbedding(config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        # NOTE: An earlier early-return on `pixel_values is None` alone made
        # the image_embeds branch below unreachable; only bail out when
        # neither input is provided.
        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)))

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds
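
    # NOTE: For pixel inputs this returns the per-image list produced by
    # Phi3HDImageEmbedding.forward, i.e. one tensor of shape
    # (num_img_tokens, hidden_size) per image; precomputed embeddings are
    # passed through unchanged.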

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.model.get_input_embeddings(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.image_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        # TODO: This is a temporary fix to load
        # the vision weights with CLIPVisionModel.load_weights()
        vision_weights = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # Skip loading the img_processor weights since they are
            # loaded separately.
            if "vision_embed_tokens.img_processor" in name:
                vision_weights.append((name, loaded_weight))
                continue

            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

        # We use regex to extract the sub-module name
        # from "model.vision_embed_tokens.img_processor.*"
        vision_weights = [
            (re.search(r"vision_embed_tokens\.img_processor\.(.*)",
                       n).group(1), w) for n, w in vision_weights
        ]
        self.vision_embed_tokens.img_processor.load_weights(vision_weights)