  1. """Implementation of SiglipVisionModel intended to be only used
  2. within a vision language model."""
  3. import math
  4. from typing import Iterable, Optional, Tuple
  5. import torch
  6. from aphrodite_flash_attn import flash_attn_func
  7. from PIL import Image
  8. from torch import nn
  9. from transformers import SiglipVisionConfig
  10. from transformers.models.siglip.modeling_siglip import SiglipAttention
  11. from xformers.ops import memory_efficient_attention
  12. from aphrodite.common.config import ModelConfig
  13. from aphrodite.common.sequence import SequenceData
  14. from aphrodite.common.utils import progress_bar
  15. from aphrodite.distributed import get_tensor_model_parallel_world_size
  16. from aphrodite.inputs import LLMInputs
  17. from aphrodite.modeling.layers.activation import get_act_fn
  18. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  19. QKVParallelLinear,
  20. RowParallelLinear)
  21. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  22. VocabParallelEmbedding)
  23. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  24. from aphrodite.multimodal.image import (cached_get_tokenizer,
  25. repeat_and_pad_image_tokens)
  26. from aphrodite.quantization import QuantizationConfig
  27. def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  28. # Since interpolation is applied, the image size need not be divisible
  29. # assert image_size % patch_size == 0
  30. return image_size // patch_size
  31. def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
  32. grid_length = get_siglip_patch_grid_length(image_size=image_size,
  33. patch_size=patch_size)
  34. return grid_length * grid_length
  35. def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
  36. return get_siglip_num_patches(image_size=hf_config.image_size,
  37. patch_size=hf_config.patch_size)
  38. def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
  39. return get_siglip_image_feature_size(hf_config)
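

# Worked example (illustrative; assumes the common SigLIP-so400m values of
# image_size=384 and patch_size=14, so substitute your checkpoint's config):
#   get_siglip_patch_grid_length(image_size=384, patch_size=14) -> 384 // 14 = 27
#   get_siglip_num_patches(image_size=384, patch_size=14)       -> 27 * 27 = 729
# i.e. such a vision tower contributes 729 image tokens per image, which is
# also what get_max_siglip_image_tokens reports for that config.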


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
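

# Sketch of what input_processor_for_siglip does (hypothetical token ids, not
# taken from any real tokenizer): with image_token_id=32000 and an image
# feature size of 729, a prompt like [1, 32000, 306, 29889] is expanded so
# that the single image placeholder is repeated 729 times, giving the language
# model one position per image patch embedding.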


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        Interpolates the pre-trained position encodings so the model can be
        used on higher-resolution images. This variant is adapted for SigLIP,
        which, unlike other ViTs, has no class embedding.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed
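
    # Worked example for the interpolation above (illustrative numbers only):
    # with patch_size=14 and 729 pre-trained positions (a 27 x 27 grid), a
    # 448 x 448 input yields a 32 x 32 patch grid, so the position embeddings
    # are bicubically resized from (1, dim, 27, 27) to (1, dim, 32, 32) and
    # flattened back to (1, 1024, dim).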

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) "
                "must be divisible by the tensor model parallel size"
                f" ({tp_size}).")

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.embed_dim // self.total_num_heads
        if self.head_dim * self.total_num_heads != self.embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")
        self.qkv_size = self.num_heads * self.head_dim
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.attn_fn = self._basic_attention_forward

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.split(
            [self.qkv_size] * 3, dim=-1)

        attn_output = self.attn_fn(
            q=query_states,
            k=key_states,
            v=value_states,
            batch_size=batch_size,
            q_len=q_len,
        )

        attn_output, _ = self.out_proj(attn_output)
        return attn_output

    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        k_v_seq_len = k.shape[-2]
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale

        if attn_weights.size() != (
                batch_size,
                self.num_heads,
                q_len,
                k_v_seq_len,
        ):
            raise ValueError(
                "Attention weights should be of size "
                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}")

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(q.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, v)

        if attn_output.size() != (
                batch_size,
                self.num_heads,
                q_len,
                self.head_dim,
        ):
            raise ValueError(
                "`attn_output` should be of size "
                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        return attn_output
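

# The subclasses below keep the projections from SiglipTPAttention and only
# swap self.attn_fn for a fused kernel (flash-attn, SDPA, or xFormers); each
# computes the same non-causal softmax(q @ k^T * scale) @ v as the eager path
# above, differing only in speed and memory use.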


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
#                       It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._flash_attention_forward

    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
                                 **kwargs):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the
                query, key, and value. (B, S, H, D)
        """
        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = flash_attn_func(
            q,
            k,
            v,
            dropout_p=self.dropout,
            causal=False,
        )

        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()
        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False
        self.attn_fn = self._sdpa_attention_forward

    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._xformers_attention_forward

    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = memory_efficient_attention(q,
                                                 k,
                                                 v,
                                                 p=0.0,
                                                 scale=self.scale)
        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()
        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES = {
    "eager": SiglipTPAttention,
    "flash_attention_2": SiglipFlashAttention2,
    "sdpa": SiglipSdpaAttention,
    "xformers": SiglipxFormersAttention,
}
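

# Hypothetical selection sketch (these classes are currently unused, see the
# NOTE above): a caller would pick a backend by key, e.g.
#   attn_cls = SIGLIP_ATTENTION_CLASSES["sdpa"]
#   self_attn = attn_cls(config, quant_config=quant_config)
# and use it in place of the upstream SiglipAttention inside SiglipEncoderLayer.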


class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require the hidden size to be a multiple of 64
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # TODO(ChristopherCho): use TP'ed Attention block
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None
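

# Each SiglipEncoderLayer above applies the standard pre-norm transformer
# update: x = x + Attn(LayerNorm1(x)) followed by x = x + MLP(LayerNorm2(x)).
# The encoder below simply stacks num_hidden_layers of these layers; shapes
# stay (batch, num_patches, hidden_size) throughout.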


class SiglipEncoder(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)
        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement aphrodite version of MHA
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]
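

# In the pooling head above, a learned probe of shape (1, 1, hidden_size) is
# broadcast to (batch, 1, hidden_size) and attends over all patch tokens, so
# the [:, 0] slice reduces (batch, num_patches, hidden_size) to a single
# pooled vector of shape (batch, hidden_size) per image.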


class SiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)
        self.use_head = (True if not hasattr(config, "vision_use_head") else
                         config.vision_use_head)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        last_hidden_state = self.post_layernorm(encoder_outputs)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        #     pooled_output = self.head(last_hidden_state)

        return last_hidden_state
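

# SiglipVisionTransformer.forward returns per-patch features of shape
# (batch, num_patches, hidden_size); the pooled output of the attention head
# is not computed or returned for now (see the TODO above).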


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
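

# Minimal usage sketch (illustrative only, not part of the module): the
# parallel linear and embedding layers above require aphrodite's distributed
# state to be initialized first, so this is shown as comments rather than
# executable code. Config values are the defaults of transformers'
# SiglipVisionConfig.
#
#   config = SiglipVisionConfig()  # hidden_size=768, image_size=224, patch_size=16
#   model = SiglipVisionModel(config)
#   pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
#   features = model(pixel_values)  # (1, (224 // 16) ** 2, 768) = (1, 196, 768)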