blip2.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
  2. import torch
  3. import torch.nn as nn
  4. from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
  5. apply_chunking_to_forward)
  6. from aphrodite.attention import AttentionMetadata
  7. from aphrodite.common.config import CacheConfig, MultiModalConfig
  8. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  9. SequenceData)
  10. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  11. from aphrodite.modeling.layers.activation import get_act_fn
  12. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  13. from aphrodite.modeling.layers.sampler import Sampler
  14. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  15. from aphrodite.modeling.models.opt import OPTModel
  16. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  17. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  18. from aphrodite.quantization import QuantizationConfig
  19. from .blip import (BlipVisionModel, dummy_image_for_blip,
  20. get_max_blip_image_tokens)
  21. from .interfaces import SupportsVision
  22. from .utils import merge_vision_embeddings
  23. _KEYS_TO_MODIFY_MAPPING = {
  24. "language_model.lm_head": "lm_head",
  25. "language_model.model": "language_model",
  26. }
  27. class Blip2QFormerMultiHeadAttention(nn.Module):
  28. def __init__(
  29. self,
  30. config: Blip2QFormerConfig,
  31. *,
  32. quant_config: Optional[QuantizationConfig],
  33. cache_config: Optional[CacheConfig],
  34. is_cross_attention: bool = False,
  35. ) -> None:
  36. super().__init__()
  37. self.config = config
  38. if config.hidden_size % config.num_attention_heads != 0:
  39. raise ValueError(
  40. f"The hidden size ({config.hidden_size}) is not a multiple of "
  41. f"the number of attention heads ({config.num_attention_heads})"
  42. )
  43. self.num_attention_heads = config.num_attention_heads
  44. self.attention_head_size = (config.hidden_size //
  45. config.num_attention_heads)
  46. self.all_head_size = self.num_attention_heads * self.attention_head_size
  47. self.scaling = self.attention_head_size**-0.5
  48. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  49. if is_cross_attention:
  50. kv_hidden_size = config.encoder_hidden_size
  51. else:
  52. kv_hidden_size = config.hidden_size
  53. self.key = nn.Linear(kv_hidden_size, self.all_head_size)
  54. self.value = nn.Linear(kv_hidden_size, self.all_head_size)
  55. self.position_embedding_type = getattr(config,
  56. "position_embedding_type",
  57. "absolute")
  58. if self.position_embedding_type != "absolute":
  59. raise NotImplementedError("Unsupported position_embedding_type: "
  60. f"{self.position_embedding_type}")
  61. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  62. def transpose_for_scores(self, x):
  63. x = x.view(*x.size()[:-1], self.num_attention_heads,
  64. self.attention_head_size)
  65. return x.permute(0, 2, 1, 3)
  66. def forward(
  67. self,
  68. hidden_states: torch.Tensor,
  69. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  70. ):
  71. is_cross_attention = encoder_hidden_states is not None
  72. if is_cross_attention:
  73. key_layer = self.transpose_for_scores(
  74. self.key(encoder_hidden_states))
  75. value_layer = self.transpose_for_scores(
  76. self.value(encoder_hidden_states))
  77. else:
  78. key_layer = self.transpose_for_scores(self.key(hidden_states))
  79. value_layer = self.transpose_for_scores(self.value(hidden_states))
  80. mixed_query_layer = self.query(hidden_states)
  81. query_layer = self.transpose_for_scores(mixed_query_layer)
  82. attention_scores = torch.matmul(query_layer,
  83. key_layer.transpose(-1, -2))
  84. attention_probs = torch.softmax(attention_scores * self.scaling,
  85. dim=-1)
  86. # This is actually dropping out entire tokens to attend to, which might
  87. # seem a bit unusual, but is taken from the original Transformer paper.
  88. attention_probs_dropped = self.dropout(attention_probs)
  89. context_layer = torch.matmul(attention_probs_dropped, value_layer)
  90. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  91. context_layer = context_layer.view(*context_layer.size()[:-2],
  92. self.all_head_size)
  93. return context_layer
  94. class Blip2QFormerSelfOutput(nn.Module):
  95. def __init__(self, config: Blip2QFormerConfig) -> None:
  96. super().__init__()
  97. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  98. self.LayerNorm = nn.LayerNorm(config.hidden_size,
  99. eps=config.layer_norm_eps)
  100. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  101. def forward(
  102. self,
  103. hidden_states: torch.Tensor,
  104. input_tensor: torch.Tensor,
  105. ) -> torch.Tensor:
  106. hidden_states = self.dense(hidden_states)
  107. hidden_states = self.dropout(hidden_states)
  108. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  109. return hidden_states
  110. class Blip2QFormerAttention(nn.Module):
  111. def __init__(
  112. self,
  113. config: Blip2QFormerConfig,
  114. *,
  115. quant_config: Optional[QuantizationConfig],
  116. cache_config: Optional[CacheConfig],
  117. is_cross_attention: bool = False,
  118. ) -> None:
  119. super().__init__()
  120. self.attention = Blip2QFormerMultiHeadAttention(
  121. config,
  122. quant_config=quant_config,
  123. cache_config=cache_config,
  124. is_cross_attention=is_cross_attention,
  125. )
  126. self.output = Blip2QFormerSelfOutput(config)
  127. def forward(
  128. self,
  129. hidden_states: torch.Tensor,
  130. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  131. ) -> Tuple[torch.Tensor]:
  132. self_output = self.attention(
  133. hidden_states,
  134. encoder_hidden_states=encoder_hidden_states,
  135. )
  136. attention_output = self.output(self_output, hidden_states)
  137. return attention_output
  138. class Blip2QFormerIntermediate(nn.Module):
  139. def __init__(self, config: Blip2QFormerConfig) -> None:
  140. super().__init__()
  141. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  142. self.intermediate_act_fn = get_act_fn(config.hidden_act)
  143. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  144. hidden_states = self.dense(hidden_states)
  145. hidden_states = self.intermediate_act_fn(hidden_states)
  146. return hidden_states
  147. class Blip2QFormerOutput(nn.Module):
  148. def __init__(self, config: Blip2QFormerConfig) -> None:
  149. super().__init__()
  150. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  151. self.LayerNorm = nn.LayerNorm(config.hidden_size,
  152. eps=config.layer_norm_eps)
  153. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  154. def forward(
  155. self,
  156. hidden_states: torch.Tensor,
  157. input_tensor: torch.Tensor,
  158. ) -> torch.Tensor:
  159. hidden_states = self.dense(hidden_states)
  160. hidden_states = self.dropout(hidden_states)
  161. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  162. return hidden_states
  163. class Blip2QFormerLayer(nn.Module):
  164. def __init__(
  165. self,
  166. config: Blip2QFormerConfig,
  167. *,
  168. quant_config: Optional[QuantizationConfig],
  169. cache_config: Optional[CacheConfig],
  170. layer_idx: int,
  171. ) -> None:
  172. super().__init__()
  173. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  174. self.seq_len_dim = 1
  175. self.attention = Blip2QFormerAttention(config,
  176. quant_config=quant_config,
  177. cache_config=cache_config)
  178. self.layer_idx = layer_idx
  179. if layer_idx % config.cross_attention_frequency == 0:
  180. self.crossattention = Blip2QFormerAttention(
  181. config,
  182. quant_config=quant_config,
  183. cache_config=cache_config,
  184. is_cross_attention=True)
  185. self.has_cross_attention = True
  186. else:
  187. self.has_cross_attention = False
  188. self.intermediate_query = Blip2QFormerIntermediate(config)
  189. self.output_query = Blip2QFormerOutput(config)
  190. def forward(
  191. self,
  192. hidden_states: torch.FloatTensor,
  193. encoder_hidden_states: torch.FloatTensor,
  194. query_length: int,
  195. ):
  196. attention_output = self.attention(hidden_states)
  197. if query_length > 0:
  198. query_attention_output = attention_output[:, :query_length, :]
  199. if self.has_cross_attention:
  200. query_attention_output = self.crossattention(
  201. query_attention_output,
  202. encoder_hidden_states=encoder_hidden_states,
  203. )
  204. layer_output = apply_chunking_to_forward(
  205. self.feed_forward_chunk_query,
  206. self.chunk_size_feed_forward,
  207. self.seq_len_dim,
  208. query_attention_output,
  209. )
  210. if attention_output.shape[1] > query_length:
  211. layer_output_text = apply_chunking_to_forward(
  212. self.feed_forward_chunk,
  213. self.chunk_size_feed_forward,
  214. self.seq_len_dim,
  215. attention_output[:, query_length:, :],
  216. )
  217. layer_output = torch.cat([layer_output, layer_output_text],
  218. dim=1)
  219. else:
  220. layer_output = apply_chunking_to_forward(
  221. self.feed_forward_chunk,
  222. self.chunk_size_feed_forward,
  223. self.seq_len_dim,
  224. attention_output,
  225. )
  226. return layer_output
  227. def feed_forward_chunk(self,
  228. attention_output: torch.Tensor) -> torch.Tensor:
  229. intermediate_output = self.intermediate(attention_output)
  230. layer_output = self.output(intermediate_output, attention_output)
  231. return layer_output
  232. def feed_forward_chunk_query(
  233. self, attention_output: torch.Tensor) -> torch.Tensor:
  234. intermediate_output = self.intermediate_query(attention_output)
  235. layer_output = self.output_query(intermediate_output, attention_output)
  236. return layer_output
  237. class Blip2QFormerEncoder(nn.Module):
  238. def __init__(
  239. self,
  240. config: Blip2QFormerConfig,
  241. *,
  242. quant_config: Optional[QuantizationConfig],
  243. cache_config: Optional[CacheConfig],
  244. ) -> None:
  245. super().__init__()
  246. self.config = config
  247. self.layer = nn.ModuleList([
  248. Blip2QFormerLayer(config,
  249. quant_config=quant_config,
  250. cache_config=cache_config,
  251. layer_idx=layer_idx)
  252. for layer_idx in range(config.num_hidden_layers)
  253. ])
  254. def forward(
  255. self,
  256. hidden_states: torch.FloatTensor,
  257. encoder_hidden_states: torch.FloatTensor,
  258. query_length: int,
  259. ) -> torch.Tensor:
  260. for i in range(self.config.num_hidden_layers):
  261. layer_module = self.layer[i]
  262. hidden_states = layer_module(
  263. hidden_states,
  264. encoder_hidden_states=encoder_hidden_states,
  265. query_length=query_length,
  266. )
  267. return hidden_states
  268. # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
  269. class Blip2QFormerModel(nn.Module):
  270. def __init__(
  271. self,
  272. config: Blip2QFormerConfig,
  273. *,
  274. quant_config: Optional[QuantizationConfig],
  275. cache_config: Optional[CacheConfig],
  276. ) -> None:
  277. super().__init__()
  278. self.config = config
  279. self.layernorm = nn.LayerNorm(config.hidden_size,
  280. eps=config.layer_norm_eps)
  281. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  282. self.encoder = Blip2QFormerEncoder(config,
  283. quant_config=quant_config,
  284. cache_config=cache_config)
  285. def forward(
  286. self,
  287. query_embeds: torch.FloatTensor,
  288. encoder_hidden_states: torch.FloatTensor,
  289. ) -> torch.Tensor:
  290. query_length = query_embeds.shape[1]
  291. embedding_output = self.layernorm(query_embeds)
  292. embedding_output = self.dropout(embedding_output)
  293. sequence_output = self.encoder(
  294. embedding_output,
  295. encoder_hidden_states=encoder_hidden_states,
  296. query_length=query_length,
  297. )
  298. return sequence_output
  299. class Blip2ImagePixelInputs(TypedDict):
  300. type: Literal["pixel_values"]
  301. data: torch.Tensor
  302. """Shape: (batch_size, num_channels, height, width)"""
  303. Blip2ImageInputs = Blip2ImagePixelInputs
  304. # We use this internally as placeholders since there is no image token
  305. # defined on the HuggingFace repo
  306. BLIP2_IMAGE_TOKEN = "<image>"
  307. BLIP2_IMAGE_TOKEN_ID = 50265
  308. def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
  309. return hf_config.num_query_tokens
  310. def get_max_blip2_image_tokens(ctx: InputContext):
  311. hf_config = ctx.get_hf_config(Blip2Config)
  312. vision_config = hf_config.vision_config
  313. if isinstance(vision_config, Blip2VisionConfig):
  314. return get_max_blip_image_tokens(vision_config)
  315. msg = f"Unsupported vision config: {type(vision_config)}"
  316. raise NotImplementedError(msg)
  317. def dummy_data_for_blip2(ctx: InputContext, seq_len: int):
  318. hf_config = ctx.get_hf_config(Blip2Config)
  319. vision_config = hf_config.vision_config
  320. image_feature_size = get_blip2_image_feature_size(hf_config)
  321. token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
  322. token_ids += [0] * (seq_len - image_feature_size)
  323. seq_data = SequenceData(token_ids)
  324. if isinstance(vision_config, Blip2VisionConfig):
  325. mm_data = dummy_image_for_blip(vision_config)
  326. return seq_data, mm_data
  327. msg = f"Unsupported vision config: {type(vision_config)}"
  328. raise NotImplementedError(msg)
  329. def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
  330. multi_modal_data = llm_inputs.get("multi_modal_data")
  331. if multi_modal_data is None or "image" not in multi_modal_data:
  332. return llm_inputs
  333. hf_config = ctx.get_hf_config(Blip2Config)
  334. image_feature_size = get_blip2_image_feature_size(hf_config)
  335. # The original model places image tokens at the front
  336. # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
  337. new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
  338. new_token_ids += llm_inputs["prompt_token_ids"]
  339. new_prompt = llm_inputs.get("prompt")
  340. if new_prompt is not None:
  341. new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
  342. return LLMInputs(prompt_token_ids=new_token_ids,
  343. prompt=new_prompt,
  344. multi_modal_data=multi_modal_data)
  345. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  346. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
  347. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
  348. @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
  349. class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
  350. def __init__(self,
  351. config: Blip2Config,
  352. multimodal_config: MultiModalConfig,
  353. cache_config: Optional[CacheConfig] = None,
  354. quant_config: Optional[QuantizationConfig] = None) -> None:
  355. super().__init__()
  356. self.config = config
  357. self.multimodal_config = multimodal_config
  358. # TODO: Optionally initializes this for supporting embeddings.
  359. self.vision_model = BlipVisionModel(config.vision_config)
  360. self.query_tokens = nn.Parameter(
  361. torch.zeros(1, config.num_query_tokens,
  362. config.qformer_config.hidden_size))
  363. self.qformer = Blip2QFormerModel(config.qformer_config,
  364. cache_config=cache_config,
  365. quant_config=quant_config)
  366. self.language_projection = nn.Linear(
  367. config.qformer_config.hidden_size,
  368. config.text_config.hidden_size,
  369. bias=True,
  370. )
  371. self.quant_config = quant_config
  372. self.language_model = OPTModel(config.text_config, cache_config,
  373. quant_config)
  374. self.unpadded_vocab_size = config.text_config.vocab_size
  375. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
  376. self.sampler = Sampler()
  377. def get_lm_head(self):
  378. return self.language_model.decoder.embed_tokens
  379. def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
  380. h = w = self.config.vision_config.image_size
  381. expected_dims = (3, h, w)
  382. actual_dims = tuple(data.shape[1:])
  383. if actual_dims != expected_dims:
  384. expected_expr = ("batch_size", *map(str, expected_dims))
  385. raise ValueError(
  386. f"The expected shape of pixel values is {expected_expr}. "
  387. f"You supplied {tuple(data.shape)}.")
  388. return data
  389. def _parse_and_validate_image_input(
  390. self, **kwargs: object) -> Optional[Blip2ImageInputs]:
  391. pixel_values = kwargs.pop("pixel_values", None)
  392. if pixel_values is None:
  393. return None
  394. if not isinstance(pixel_values, torch.Tensor):
  395. raise ValueError("Incorrect type of pixel values. "
  396. f"Got type: {type(pixel_values)}")
  397. return Blip2ImagePixelInputs(
  398. type="pixel_values",
  399. data=self._validate_pixel_values(pixel_values),
  400. )
  401. def _image_pixels_to_features(self, vision_model: BlipVisionModel,
  402. pixel_values: torch.Tensor) -> torch.Tensor:
  403. # NOTE: we skip the step to select the vision feature layer since
  404. # this is already done inside the vision tower
  405. image_features = vision_model(pixel_values)
  406. return image_features
  407. def _process_image_pixels(self,
  408. inputs: Blip2ImagePixelInputs) -> torch.Tensor:
  409. assert self.vision_model is not None
  410. pixel_values = inputs["data"]
  411. return self._image_pixels_to_features(self.vision_model, pixel_values)
  412. def _process_image_input(self,
  413. image_input: Blip2ImageInputs) -> torch.Tensor:
  414. assert self.vision_model is not None
  415. image_features = self._process_image_pixels(image_input)
  416. query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
  417. -1)
  418. query_output = self.qformer(
  419. query_embeds=query_tokens,
  420. encoder_hidden_states=image_features,
  421. )
  422. return self.language_projection(query_output)
  423. def forward(
  424. self,
  425. input_ids: torch.Tensor,
  426. positions: torch.Tensor,
  427. kv_caches: List[torch.Tensor],
  428. attn_metadata: AttentionMetadata,
  429. intermediate_tensors: Optional[IntermediateTensors] = None,
  430. **kwargs: object,
  431. ) -> SamplerOutput:
  432. """Run forward pass for BLIP-2.
  433. One key thing to understand is the `input_ids` already accounts for the
  434. positions of the to-be-inserted image embeddings.
  435. Concretely, consider a text prompt:
  436. `"Question: What's the content of the image? Answer:"`.
  437. Tokenizer outputs:
  438. `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
  439. To reserve space in KV cache, we have to insert placeholder tokens
  440. before they are inputted to the model, so the input processor prepends
  441. dummy tokens (denoted as `50265`), resulting in:
  442. `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
  443. We insert 32 tokens since it corresponds to the number of query
  444. embeddings outputted by the Q-Former and inputted to the language model.
  445. This way, the `positions` and `attn_metadata` are consistent
  446. with the `input_ids`.
  447. Args:
  448. input_ids: Flattened (concatenated) input_ids corresponding to a
  449. batch.
  450. pixel_values: The pixels in each input image.
  451. See also:
  452. :class:`Blip2ImageInputs`
  453. """
  454. image_input = self._parse_and_validate_image_input(**kwargs)
  455. if image_input is not None:
  456. vision_embeddings = self._process_image_input(image_input)
  457. inputs_embeds = self.language_model.get_input_embeddings(input_ids)
  458. inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
  459. vision_embeddings,
  460. BLIP2_IMAGE_TOKEN_ID)
  461. input_ids = None
  462. else:
  463. inputs_embeds = None
  464. hidden_states = self.language_model(input_ids,
  465. positions,
  466. kv_caches,
  467. attn_metadata,
  468. inputs_embeds=inputs_embeds)
  469. return hidden_states
  470. def compute_logits(self, hidden_states: torch.Tensor,
  471. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  472. logits = self.logits_processor(self.get_lm_head(), hidden_states,
  473. sampling_metadata)
  474. return logits
  475. def sample(
  476. self,
  477. logits: torch.Tensor,
  478. sampling_metadata: SamplingMetadata,
  479. ) -> Optional[SamplerOutput]:
  480. next_tokens = self.sampler(logits, sampling_metadata)
  481. return next_tokens
  482. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  483. # only doing this for language model part for now.
  484. stacked_params_mapping = [
  485. # (param_name, shard_name, shard_id)
  486. ("qkv_proj", "q_proj", "q"),
  487. ("qkv_proj", "k_proj", "k"),
  488. ("qkv_proj", "v_proj", "v"),
  489. ("gate_up_proj", "gate_proj", 0),
  490. ("gate_up_proj", "up_proj", 1),
  491. ]
  492. params_dict = dict(self.named_parameters())
  493. for name, loaded_weight in weights:
  494. if "lm_head.weight" in name:
  495. continue
  496. if "rotary_emb.inv_freq" in name:
  497. continue
  498. for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
  499. if key_to_modify in name:
  500. name = name.replace(key_to_modify, new_key)
  501. use_default_weight_loading = False
  502. if "vision" in name:
  503. if self.vision_model is not None:
  504. # We only do sharding for language model and
  505. # not vision model for now.
  506. use_default_weight_loading = True
  507. else:
  508. for (param_name, weight_name,
  509. shard_id) in stacked_params_mapping:
  510. if weight_name not in name:
  511. continue
  512. param = params_dict[name.replace(weight_name, param_name)]
  513. weight_loader = param.weight_loader
  514. weight_loader(param, loaded_weight, shard_id)
  515. break
  516. else:
  517. use_default_weight_loading = True
  518. if use_default_weight_loading:
  519. param = params_dict[name]
  520. weight_loader = getattr(param, "weight_loader",
  521. default_weight_loader)
  522. weight_loader(param, loaded_weight)