# coding=utf-8
# Copyright 2024 The PygmalionAI team.
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, ModelConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

# Cannot find the following number in the HF config.
_IMAGE_TOKEN_ID = 32044

# Results in the max possible feature size (h:w = 16:1)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)
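
# NOTE: With CLIP_VIT_LARGE_PATCH14_336_CONFIG above, each 336x336 crop is
# divided into (336 // 14) ** 2 = 24 * 24 = 576 patches, so the "patch"
# feature selection below yields 576 tokens of width 1024 per crop.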


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        TYPE_FEATURE = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if TYPE_FEATURE == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
        self.layer_idx = config.img_processor.get('layer_idx', -2)

        # Initialize the CLIP only up to the required feature layer
        if self.layer_idx < 0:
            num_hidden_layers = clip_config.num_hidden_layers + \
                self.layer_idx + 1
        else:
            num_hidden_layers = self.layer_idx + 1

        self.img_processor = CLIPVisionModel(
            clip_config, num_hidden_layers_override=num_hidden_layers)

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']
        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')
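
        # NOTE: After the 2x2 patch merge, four neighbouring 1024-d CLIP patch
        # features are concatenated, so the projection input width is
        # image_dim_out * 4 = 4096; the 2-layer GELU MLP above maps it to the
        # language model's hidden size.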

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        output: (num_images, num_img_tokens, hidden_size)
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        image_features_proj = self.hd_feature_transform(
            img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """
        image_features: (num_images, num_crops+1, 24*24, 1024)
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:,
                                               0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        all_image_embeddings = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            all_image_embeddings.append(
                torch.cat([
                    sub_image_features_hd_newline.squeeze(
                        0),  # (h_crop*12*(w_crop*12+1), 4096)
                    self.glb_GN.squeeze(0),
                    global_image_features_hd_newline[i],
                ]))

        image_features_proj = self.img_projection(
            torch.stack(all_image_embeddings).to(target_device, target_dtype)
        )  # (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)

        return image_features_proj
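
    # NOTE: Rough worked example (sizes chosen for illustration): a padded
    # 672x1008 image gives h_crop = 2, w_crop = 3, so each image contributes
    # h_crop*12*(w_crop*12+1) = 888 sub-image tokens, 1 glb_GN separator and
    # 12*13 = 156 global tokens, i.e. 1045 tokens before projection.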

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd
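
    # NOTE: Shape example: with h_crop = 2, w_crop = 3 and a single image, the
    # input is (6, 576, 1024) and the 2x2 merge yields
    # (1, 2*12, 3*12, 4*1024) = (1, 24, 36, 4096).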

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
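
    # NOTE: Continuing the shape example, appending one sub_GN "newline"
    # embedding per row turns (1, 24, 36, 4096) into (1, 24*37, 4096)
    # = (1, 888, 4096).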


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height
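
# NOTE: For example, _calc_padded_size(width=1680, height=840) pads the height
# up to the next multiple of 336 and returns (1680, 1008).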


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = _calc_padded_size(width=new_width,
                                                    height=new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height
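
# NOTE: For example, _calc_hd_transform_size(width=1000, height=500) selects
# scale = 5 (the largest scale with scale * ceil(scale / ratio) <= 16),
# resizes to 1680x840 and pads to 1680x1008.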


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
    hf_config: PretrainedConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    num_crops = getattr(hf_config, "num_crops", 16)
    new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                    height=input_height,
                                                    hd_num=num_crops)

    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
        + (new_height // 336 + 1) * 12
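
# NOTE: Continuing the example, a 1000x500 input becomes 1680x1008 after the
# HD transform, so its feature size is
# (3 * 5 + 1) * 144 + 1 + (3 + 1) * 12 = 2353 image tokens.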


def get_max_phi3v_image_tokens(ctx: InputContext):
    return get_phi3v_image_feature_size(
        ctx.get_hf_config(PretrainedConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )
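
# NOTE: With the extreme 50x8000 (w x h) size above, the HD transform produces
# a 336x5376 padded image, so the maximum works out to
# (16 * 1 + 1) * 144 + 1 + (16 + 1) * 12 = 2653 image tokens.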


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
    image_feature_size = get_max_phi3v_image_tokens(ctx)

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )

    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
                                     idx: int) -> List[int]:
    assert idx > 0

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    # We need to get the token for "<", not "▁<"
    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
        f"a<|image_{idx}|>", add_special_tokens=False)
    assert a_token_id == a_token_id_

    return image_placeholder_token_ids
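
# NOTE: The "a" prefix is a tokenizer workaround: encoding "<|image_1|>" on its
# own would start with the "▁<" piece, whereas inside a prompt the placeholder
# follows other text and starts with a plain "<". Prepending "a" and dropping
# its token recovers the in-context placeholder token ids.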


def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        w, h = image_data.size
        w, h = _calc_hd_transform_size(width=w, height=h)

        image_feature_size = get_phi3v_image_feature_size(hf_config,
                                                          input_width=w,
                                                          input_height=h)
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        new_prompt = None
    else:
        if prompt.count("<|image|>") > 0:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           "repeating <|image|> tokens.")
        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        new_prompt = prompt

    prompt_token_ids = llm_inputs["prompt_token_ids"]
    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)

    new_token_ids: List[int] = []
    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
            new_token_ids.append(_IMAGE_TOKEN_ID)

            # No need to further scan the list since we only replace once
            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
            break
        else:
            new_token_ids.append(prompt_token_ids[i])

    # NOTE: Create a defensive copy of the original inputs
    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                           prompt=new_prompt,
                           multi_modal_data=multi_modal_data)

    return input_processor_for_clip(
        model_config,
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        llm_inputs,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )
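
# NOTE: The processor above collapses the "<|image_1|>" placeholder into a
# single _IMAGE_TOKEN_ID; input_processor_for_clip (see .clip) then expands it
# to image_feature_size repeated image tokens so the prompt reserves one
# position per vision embedding produced by Phi3HDImageEmbedding.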


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        self.model = LlamaModel(config, cache_config, quant_config)

        # TODO: Optionally initialize this to support embedding inputs.
        self.vision_embed_tokens = Phi3HDImageEmbedding(config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected shape of image sizes is batch dimension plus "
                f"{[2]}. You supplied {tuple(data.shape)}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return Phi3VImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes))

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self.vision_embed_tokens(
                image_input["data"], image_input["image_sizes"])
            inputs_embeds = self.model.get_input_embeddings(input_ids)
            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                    vision_embeddings,
                                                    self.image_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)

        return hidden_states
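
    # NOTE: In forward(), merge_vision_embeddings scatters the projected image
    # features into the positions of input_ids equal to image_token_id, so the
    # language model consumes inputs_embeds instead of token ids for those
    # positions.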

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # We only do sharding for language model
                # and not vision model for now.
                if "vision_embed_tokens" in name and self.vision_embed_tokens:
                    continue
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)