phi3v.py

# coding=utf-8
# Copyright 2024 The PygmalionAI team.
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, VisionLanguageConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision

_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)
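
# Note: with image_size=336 and patch_size=14, this vision tower produces a
# 24x24 grid of patch embeddings (576 patch tokens) plus one CLS token per
# crop, which is why the HD transform below reasons in 24x24 / 12x12 grids.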


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self, wte=None) -> None:
        super().__init__()
        self.wte = wte
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        LAYER_IDX = self.layer_idx
        TYPE_FEATURE = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds,
                                         vision_feature_layer=LAYER_IDX)

        if TYPE_FEATURE == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(self,
                 vision_language_config: VisionLanguageConfig,
                 config: PretrainedConfig,
                 wte=None) -> None:
        super().__init__(wte)

        self.image_token_id = vision_language_config.image_token_id
        # n_embd or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
        self.img_processor = CLIPVisionModel(clip_config)
        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # use_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.vocab_size = config.vocab_size

        self.layer_idx = config.img_processor.get('layer_idx', -2)
        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self, input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
        """Process and merge text embeddings with image embeddings."""

        # (batch_size, max_num_crops, 3, height, width)
        img_embeds = pixel_values

        # (batch_size, 2)
        img_sizes = image_sizes

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        positions = torch.nonzero(input_ids == self.image_token_id)

        select = False

        target_dtype = self.img_projection[0].bias.dtype

        if len(positions.tolist()) > 0:
            # if self.use_hd_transform and img_sizes:
            # img_embeds: (num_images, max_num_crops, 3, H, W)
            # img_sizes: (num_images, 2).view(1, -1)

            bs = img_embeds.shape[0]
            # Nx(HW)xC
            img_features = self.get_img_features(img_embeds.flatten(0, 1))
            base_feat_height = base_feat_width = int(
                img_features.shape[1]**0.5)

            # bs x max_num_crops x (24x24) x C
            img_features = img_features.view(
                bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
            C = self.image_dim_out
            H = base_feat_height

            output_imgs = []
            output_len = []

            for _bs in range(bs):
                h, w = img_sizes[_bs]
                h = h // 336
                w = w // 336
                B_ = h * w

                # 1 x (24x24) x 1024
                global_img_feature = img_features[_bs, :1]

                # 1 x 12 x 12 x 4096
                glb_img = global_img_feature \
                    .reshape(1, H // 2, 2, H // 2, 2, C) \
                    .permute(0, 1, 3, 2, 4, 5) \
                    .reshape(1, H // 2, H // 2, 4 * C)
                temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)

                # 1 x 156 x 4096
                glb_img = torch.cat([glb_img, temp_glb_GN],
                                    dim=2).reshape(1, -1, 4 * C)

                # (max_num_crops-1) x (12x12) x C
                sub_img = img_features[_bs, 1:]
                # 16x576x1024
                # get rid of padding sub_img
                sub_img = sub_img[:B_]

                sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \
                    .permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C)
                sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \
                    .permute(0, 1, 3, 2, 4, 5) \
                    .reshape(1, h * 12, w * 12, 4 * C)
                temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
                sub_img = torch.cat([sub_img, temp_sub_GN],
                                    dim=2).reshape(1, -1, 4 * C)
                # (1, num_img_tokens, 1024*4)

                # glb + sub
                if self.hd_transform_order == 'glb_sub':
                    output_imgs.append(
                        torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
                elif self.hd_transform_order == 'sub_glb':
                    output_imgs.append(
                        torch.cat([sub_img, self.glb_GN, glb_img], dim=1))

                temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
                output_len.append(temp_len)

            num_img_tokens = output_len
            img_set_tensor = []
            for _output_img in output_imgs:
                img_feature_proj = self.img_projection(
                    _output_img.to(target_dtype))
                img_set_tensor.append(img_feature_proj)
            select = True

        input_ids.clamp_min_(0).clamp_max_(self.vocab_size)

        hidden_states = self.wte(input_ids)

        if select:
            idx = 0
            for i, cnt in enumerate(num_img_tokens):
                hidden_states[positions[idx, 0],
                              positions[idx, 1]:positions[idx, 1] +
                              cnt] = (img_set_tensor[i].to(
                                  hidden_states.dtype))
                idx += cnt

        return hidden_states.squeeze(0)
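
# Illustrative walk-through of the HD transform above (not part of the model),
# assuming a single 1344x1344 input (h = w = 1344 // 336 = 4, i.e. 16 sub-crops
# plus one global crop): each crop yields a 24x24 grid of 1024-d patch
# features. The 2x2 spatial-to-channel merge turns every crop into a 12x12 grid
# of 4096-d features; the global crop becomes 12 x (12 + 1) = 156 tokens after
# appending sub_GN as a per-row separator, the 16 sub-crops tile into a
# 48 x (48 + 1) = 2352-token grid, and one glb_GN token separates the two, for
# (16 + 1) * 144 + 1 + (4 + 1) * 12 = 2509 image tokens in total.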


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""

    image_sizes: torch.Tensor
    """Shape: (batch_size, 2)"""


def _get_phi3v_image_feature_size(
    *,
    input_height: int,
    input_width: int,
) -> int:
    h, w = input_height, input_width

    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
    return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
    multimodal_config = ctx.get_multimodal_config()

    # TODO: change the logic for dummy data to support dynamic shape
    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
    image_feature_size = _get_phi3v_image_feature_size(
        input_height=dummy_height,
        input_width=dummy_width,
    )

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        image_token_id=32044,
        image_feature_size_override=image_feature_size,
    )
    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        image_width_override=dummy_width,
        image_height_override=dummy_height,
    )

    return seq_data, mm_data
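
# Note: dummy data registered this way is intended for memory profiling. Here
# it pairs a sequence pre-filled with `image_feature_size` copies of the image
# placeholder token (id 32044) with a blank PIL image at the configured input
# shape, approximating the worst-case multimodal request.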


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = _calc_padded_size(width=new_width,
                                                    height=new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height
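
# Example (illustrative): a 1000x700 image has aspect ratio 10:7, so the HD
# transform picks scale = 4 (the largest scale with
# scale * ceil(scale / ratio) <= 16), resizes to 1344x940, and pads the height
# up to the next multiple of 336, giving (1344, 1008) == (4 * 336, 3 * 336).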


def _image_processor(ctx: InputContext,
                     image: object) -> Dict[str, torch.Tensor]:

    if isinstance(image, Image.Image):
        # Temporary patch before dynamic number of image tokens is supported
        _, _, h, w = ctx.get_multimodal_config().image_input_shape
        if (w, h) != _calc_hd_transform_size(width=image.width,
                                             height=image.height):
            logger.warning("Dynamic image shape is currently not supported. "
                           f"Resizing input image to ({w}, {h}).")

            image = image.resize((w, h))

        return MULTIMODAL_REGISTRY._get_plugin("image") \
            ._default_input_mapper(ctx, image)

    raise TypeError(f"Invalid type for 'image': {type(image)}")


@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):

    def __init__(self,
                 config: PretrainedConfig,
                 vlm_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.vlm_config = vlm_config

        self.model = LlamaModel(config, cache_config, quant_config)
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            vlm_config, config, self.model.embed_tokens)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is not None and image_sizes is not None:
            return Phi3VImagePixelInputs(type="pixel_values",
                                         data=pixel_values,
                                         image_sizes=image_sizes)

        return None

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.vision_embed_tokens(
                input_ids, image_input["data"], image_input["image_sizes"])

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
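        # For example, a checkpoint entry such as
        # "model.layers.0.self_attn.q_proj.weight" is loaded into the fused
        # "model.layers.0.self_attn.qkv_proj.weight" parameter as its "q"
        # shard, and "...mlp.gate_proj.weight" / "...mlp.up_proj.weight"
        # become shards 0 and 1 of "...mlp.gate_up_proj.weight".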
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # We only do sharding for language model
                # and not vision model for now.
                if "vision_embed_tokens" in name and self.vision_embed_tokens:
                    continue
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)