siglip.py

  1. """Implementation of SiglipVisionModel intended to be only used
  2. within a vision language model."""
  3. import math
  4. from typing import Iterable, Optional, Tuple
  5. import torch
  6. from aphrodite_flash_attn import flash_attn_func
  7. from PIL import Image
  8. from torch import nn
  9. from transformers import SiglipVisionConfig
  10. from transformers.models.siglip.modeling_siglip import SiglipAttention
  11. from xformers.ops import memory_efficient_attention
  12. from aphrodite.common.config import ModelConfig
  13. from aphrodite.common.sequence import SequenceData
  14. from aphrodite.distributed import get_tensor_model_parallel_world_size
  15. from aphrodite.inputs import LLMInputs
  16. from aphrodite.modeling.layers.activation import get_act_fn
  17. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  18. QKVParallelLinear,
  19. RowParallelLinear)
  20. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  21. VocabParallelEmbedding)
  22. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  23. from aphrodite.multimodal.image import (cached_get_tokenizer,
  24. repeat_and_pad_image_tokens)
  25. from aphrodite.quantization import QuantizationConfig
  26. def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  27. # Since interpolation is applied, the image size need not be divisible
  28. # assert image_size % patch_size == 0
  29. return image_size // patch_size
  30. def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
  31. grid_length = get_siglip_patch_grid_length(image_size=image_size,
  32. patch_size=patch_size)
  33. return grid_length * grid_length
  34. def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
  35. return get_siglip_num_patches(image_size=hf_config.image_size,
  36. patch_size=hf_config.patch_size)
  37. def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
  38. return get_siglip_image_feature_size(hf_config)
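
# For example, under an assumed SigLIP config with image_size=384 and
# patch_size=14, the image feature size is (384 // 14) ** 2 = 27 * 27 = 729,
# i.e. each input image contributes 729 placeholder tokens to the prompt.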


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
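
# Illustrative only: with image_feature_size == 4, prompt token ids such as
# [p0, <image>, p1] are roughly expanded to
# [p0, <image>, <image>, <image>, <image>, p1], so that there is one
# placeholder position per patch embedding; the exact padding behavior is
# determined by repeat_and_pad_image_tokens.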


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249  # noqa
class SiglipVisionEmbeddings(nn.Module):

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        Interpolates the pre-trained position encodings so that the model can
        be used on higher-resolution images. This is an adaptation for SigLIP,
        which, unlike other ViTs, has no class embedding.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # We add a small number to avoid a floating point error in the
        # interpolation; see the discussion at
        # https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]

        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings
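
# Shape flow through SiglipVisionEmbeddings.forward, for reference:
#   pixel_values:  (batch, num_channels, height, width)
#   patch_embeds:  (batch, embed_dim, height // patch_size, width // patch_size)
#   embeddings:    (batch, num_patches, embed_dim) after flatten(2).transpose(1, 2),
#                  with position embeddings (interpolated if requested) added.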


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) "
                "must be divisible by the tensor model parallel size"
                f" ({tp_size}).")

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.embed_dim // self.total_num_heads
        if self.head_dim * self.total_num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads (got "
                f"`embed_dim`: {self.embed_dim} and `num_heads`: "
                f"{self.num_heads}).")

        self.qkv_size = self.num_heads * self.head_dim
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.attn_fn = self._basic_attention_forward

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.split(
            [self.qkv_size] * 3, dim=-1)

        attn_output = self.attn_fn(
            q=query_states,
            k=key_states,
            v=value_states,
            batch_size=batch_size,
            q_len=q_len,
        )

        attn_output, _ = self.out_proj(attn_output)
        return attn_output

    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        k_v_seq_len = k.shape[-2]
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale

        if attn_weights.size() != (
                batch_size,
                self.num_heads,
                q_len,
                k_v_seq_len,
        ):
            raise ValueError(
                "Attention weights should be of size "
                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}")

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(q.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, v)

        if attn_output.size() != (
                batch_size,
                self.num_heads,
                q_len,
                self.head_dim,
        ):
            raise ValueError(
                "`attn_output` should be of size "
                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        return attn_output
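
# Worked example of the tensor-parallel head split (assumed sizes, for
# illustration only): with embed_dim=1152, total_num_heads=16 and tp_size=2,
# each rank owns num_heads = 16 // 2 = 8 heads of head_dim = 1152 // 16 = 72,
# so the per-rank q/k/v slices produced by qkv_proj each have width
# qkv_size = 8 * 72 = 576.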


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
#                       It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._flash_attention_forward

    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
                                 **kwargs):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the
                     query, key, and value. (B, S, H, D)
        """
        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = flash_attn_func(
            q,
            k,
            v,
            dropout_p=self.dropout,
            causal=False,
        )

        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False
        self.attn_fn = self._sdpa_attention_forward

    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._xformers_attention_forward

    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = memory_efficient_attention(q,
                                                 k,
                                                 v,
                                                 p=0.0,
                                                 scale=self.scale)
        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES = {
    "eager": SiglipTPAttention,
    "flash_attention_2": SiglipFlashAttention2,
    "sdpa": SiglipSdpaAttention,
    "xformers": SiglipxFormersAttention,
}
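
# Illustrative only: once the TP'ed attention path is enabled, a backend could
# be chosen by key from this mapping, e.g.
#   attn_cls = SIGLIP_ATTENTION_CLASSES["sdpa"]
#   self_attn = attn_cls(config, quant_config=quant_config)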


class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require the hidden size to be a multiple of 64
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # TODO(ChristopherCho): use TP'ed Attention block
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement aphrodite version of MHA
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]
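
# Pooling shapes, for reference: the learned probe (1, 1, hidden_size) is tiled
# to (batch, 1, hidden_size) and used as the attention query over the encoder
# output (batch, num_patches, hidden_size); the MHA output is therefore
# (batch, 1, hidden_size), and indexing [:, 0] yields one pooled vector of
# shape (batch, hidden_size) per image.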


class SiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)
        self.use_head = (True if not hasattr(config, "vision_use_head") else
                         config.vision_use_head)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        last_hidden_state = self.post_layernorm(encoder_outputs)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        #     pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
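
# Rough usage sketch (hypothetical, assuming aphrodite's model-parallel state
# has already been initialized by the enclosing VLM's loader):
#   config = SiglipVisionConfig(...)  # or taken from the VLM's hf_config
#   vision_tower = SiglipVisionModel(config, quant_config=None)
#   vision_tower.load_weights(hf_state_dict.items())  # hf_state_dict: a dict of
#                                                     # HF-format weights
#   image_features = vision_tower(pixel_values)  # (batch, num_patches, hidden)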