phi3v.py

# coding=utf-8
# Copyright 2024 The PygmalionAI team.
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import (CacheConfig, ModelConfig,
                                     VisionLanguageConfig)
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   input_processor_for_clip)
from .interfaces import SupportsVision

_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)
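
# With `image_size=336` and `patch_size=14`, the CLIP vision tower yields
# 24 x 24 = 576 patch tokens (plus one CLS token) per 336 x 336 crop. The HD
# transform below folds each 2 x 2 block of patches into one
# 4 * 1024 = 4096-dim feature, i.e. 12 x 12 = 144 merged tokens per crop.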


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self, wte=None) -> None:
        super().__init__()
        self.wte = wte
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        LAYER_IDX = self.layer_idx
        TYPE_FEATURE = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds,
                                         vision_feature_layer=LAYER_IDX)

        if TYPE_FEATURE == "patch":
            # Drop the CLS token and keep only the patch tokens.
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            # Keep the CLS token together with the patch tokens.
            return img_feature

        raise NotImplementedError


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(self,
                 vision_language_config: VisionLanguageConfig,
                 config: PretrainedConfig,
                 wte=None) -> None:
        super().__init__(wte)

        self.image_token_id = vision_language_config.image_token_id
        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
        self.img_processor = CLIPVisionModel(clip_config)
        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.vocab_size = config.vocab_size

        self.layer_idx = config.img_processor.get('layer_idx', -2)
        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self, input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
        """Process and merge text embeddings with image embeddings."""
        # (batch_size, max_num_crops, 3, height, width)
        img_embeds = pixel_values
        # (batch_size, 2)
        img_sizes = image_sizes

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        positions = torch.nonzero(input_ids == self.image_token_id)

        select = False

        target_dtype = self.img_projection[0].bias.dtype

        if len(positions.tolist()) > 0:
            # if self.use_hd_transform and img_sizes:
            # img_embeds: (num_images, max_num_crops, 3, H, W)
            # img_sizes: (num_images, 2).view(1, -1)

            bs = img_embeds.shape[0]
            # Nx(HW)xC
            img_features = self.get_img_features(img_embeds.flatten(0, 1))
            base_feat_height = base_feat_width = int(
                img_features.shape[1]**0.5)

            # bs x max_num_crops x (24x24) x C
            img_features = img_features.view(
                bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
            C = self.image_dim_out
            H = base_feat_height

            output_imgs = []
            output_len = []

            for _bs in range(bs):
                h, w = img_sizes[_bs]
                h = h // 336
                w = w // 336
                B_ = h * w

                # 1 x (24x24) x 1024
                global_img_feature = img_features[_bs, :1]

                # 1 x 12 x 12 x 4096
                glb_img = global_img_feature \
                    .reshape(1, H // 2, 2, H // 2, 2, C) \
                    .permute(0, 1, 3, 2, 4, 5) \
                    .reshape(1, H // 2, H // 2, 4 * C)
                temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)

                # 1 x 156 x 4096
                glb_img = torch.cat([glb_img, temp_glb_GN],
                                    dim=2).reshape(1, -1, 4 * C)

                # (max_num_crops-1) x (24x24) x C
                sub_img = img_features[_bs, 1:]
                # 16 x 576 x 1024
                # get rid of padding sub_img
                sub_img = sub_img[:B_]

                sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \
                    .permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C)
                sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \
                    .permute(0, 1, 3, 2, 4, 5) \
                    .reshape(1, h * 12, w * 12, 4 * C)
                temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
                sub_img = torch.cat([sub_img, temp_sub_GN],
                                    dim=2).reshape(1, -1, 4 * C)
                # (1, num_img_tokens, 1024*4)

                # glb + sub
                if self.hd_transform_order == 'glb_sub':
                    output_imgs.append(
                        torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
                elif self.hd_transform_order == 'sub_glb':
                    output_imgs.append(
                        torch.cat([sub_img, self.glb_GN, glb_img], dim=1))

                temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
                output_len.append(temp_len)

            num_img_tokens = output_len
            img_set_tensor = []
            for _output_img in output_imgs:
                img_feature_proj = self.img_projection(
                    _output_img.to(target_dtype))
                img_set_tensor.append(img_feature_proj)
            select = True

        input_ids.clamp_min_(0).clamp_max_(self.vocab_size)

        hidden_states = self.wte(input_ids)

        if select:
            idx = 0
            for i, cnt in enumerate(num_img_tokens):
                hidden_states[positions[idx, 0],
                              positions[idx, 1]:positions[idx, 1] +
                              cnt] = (img_set_tensor[i].to(
                                  hidden_states.dtype))
                idx += cnt

        return hidden_states.squeeze(0)


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: BatchedTensors
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = _calc_padded_size(width=new_width,
                                                    height=new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height
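

# Illustrative example (values not from the original source): a 1000 x 600
# image has ratio ~1.67, so the largest `scale` with
# scale * ceil(scale / ratio) <= 16 is 5. The image is resized to 1680 x 1008,
# which needs no extra padding, i.e. a 5 x 3 grid of 336 x 336 crops plus the
# globally resized image.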


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
    hf_config: PretrainedConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    num_crops = getattr(hf_config, "num_crops", 16)
    new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                    height=input_height,
                                                    hd_num=num_crops)

    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
        + (new_height // 336 + 1) * 12
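

# The returned count matches `temp_len` in Phi3HDImageEmbedding.forward:
#   (h*w + 1) * 144   merged 12 x 12 tokens for the h*w sub-crops plus the
#                     global crop,
#   (h + 1) * 12      one `sub_GN` separator per row of merged features
#                     (h * 12 rows for the sub-crops, 12 for the global crop),
#   + 1               the single `glb_GN` token between the two blocks.
# For the illustrative 1680 x 1008 example above (h=3, w=5):
#   (15 + 1) * 144 + 1 + (3 + 1) * 12 = 2353 image tokens.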


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
    # Results in the maximum possible feature size (h:w = 16:1)
    dummy_height, dummy_width = 8000, 50
    image_feature_size = get_phi3v_image_feature_size(
        ctx.get_hf_config(PretrainedConfig),
        input_height=dummy_height,
        input_width=dummy_width,
    )

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        image_token_id=32044,
        image_feature_size_override=image_feature_size,
    )

    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        image_width_override=dummy_width,
        image_height_override=dummy_height,
    )

    return seq_data, mm_data


# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
                                     idx: int) -> List[int]:
    assert idx > 0

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    # We need to get the token for "<", not "▁<"
    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
        f"a<|image_{idx}|>", add_special_tokens=False)
    assert a_token_id == a_token_id_

    return image_placeholder_token_ids


def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    multimodal_config = ctx.get_multimodal_config()
    hf_config = ctx.get_hf_config(PretrainedConfig)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        w, h = image_data.size
        w, h = _calc_hd_transform_size(width=w, height=h)

        image_feature_size = get_phi3v_image_feature_size(hf_config,
                                                          input_width=w,
                                                          input_height=h)
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        new_prompt = None
    else:
        if prompt.count("<|image|>") > 0:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           "repeating <|image|> tokens.")
        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        new_prompt = prompt

    prompt_token_ids = llm_inputs["prompt_token_ids"]
    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)

    new_token_ids: List[int] = []
    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
            new_token_ids.append(multimodal_config.image_token_id)

            # No need to further scan the list since we only replace once
            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
            break
        else:
            new_token_ids.append(prompt_token_ids[i])

    # NOTE: Create a defensive copy of the original inputs
    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                           prompt=new_prompt,
                           multi_modal_data=multi_modal_data)

    return input_processor_for_clip(
        model_config,
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        llm_inputs,
        image_token_id=multimodal_config.image_token_id,
        image_feature_size_override=image_feature_size,
    )
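

# NOTE: the processor above only swaps the `<|image_1|>` placeholder for a
# single `image_token_id`; `input_processor_for_clip` is then expected to
# expand that token into `image_feature_size` repeated image tokens so the
# scheduler reserves the right number of positions for the image embeddings.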


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):

    def __init__(self,
                 config: PretrainedConfig,
                 vlm_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.vlm_config = vlm_config

        self.model = LlamaModel(config, cache_config, quant_config)

        # TODO: Optionally initialize this to support embedding inputs.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            vlm_config, config, self.model.embed_tokens)

        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return Phi3VImagePixelInputs(type="pixel_values",
                                     data=pixel_values,
                                     image_sizes=image_sizes)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.vision_embed_tokens(
                input_ids, image_input["data"], image_input["image_sizes"])

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
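        # Checkpoints store q/k/v and gate/up as separate tensors; the mapping
        # above folds them into the fused qkv_proj and gate_up_proj parameters
        # (with the appropriate shard id) when loading the language model.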
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # We only do sharding for language model
                # and not vision model for now.
                if "vision_embed_tokens" in name and self.vision_embed_tokens:
                    continue
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)