# siglip.py

  1. """Implementation of SiglipVisionModel intended to be only used
  2. within a vision language model."""
  3. import math
  4. from typing import Iterable, Optional, Tuple
  5. import torch
  6. from aphrodite_flash_attn import flash_attn_func
  7. from PIL import Image
  8. from torch import nn
  9. from transformers import SiglipVisionConfig
  10. from transformers.models.siglip.modeling_siglip import SiglipAttention
  11. from xformers.ops import memory_efficient_attention
  12. from aphrodite.common.config import ModelConfig
  13. from aphrodite.common.sequence import SequenceData
  14. from aphrodite.distributed import get_tensor_model_parallel_world_size
  15. from aphrodite.inputs import LLMInputs
  16. from aphrodite.modeling.layers.activation import get_act_fn
  17. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  18. QKVParallelLinear,
  19. RowParallelLinear)
  20. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  21. VocabParallelEmbedding)
  22. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  23. from aphrodite.multimodal.image import (cached_get_tokenizer,
  24. repeat_and_pad_image_tokens)
  25. from aphrodite.quantization import QuantizationConfig


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    # Since interpolation is applied, the image size need not be divisible
    # assert image_size % patch_size == 0
    return image_size // patch_size


def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_siglip_patch_grid_length(image_size=image_size,
                                               patch_size=patch_size)
    return grid_length * grid_length


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_num_patches(image_size=hf_config.image_size,
                                  patch_size=hf_config.patch_size)


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_image_feature_size(hf_config)
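

# Illustrative arithmetic only (assumed config values): for a SigLIP config
# with image_size=384 and patch_size=14 (as in the so400m checkpoints), the
# helpers above give a 384 // 14 = 27 patch grid, i.e. 27 * 27 = 729 image
# tokens per image:
#
#   >>> get_siglip_num_patches(image_size=384, patch_size=14)
#   729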


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
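

# Rough sketch of the effect (hypothetical token ids, for illustration only):
# if image_token_id were 32000 and image_feature_size were 729, a prompt token
# sequence such as [1, 32000, 15043] would have its single image placeholder
# repeated 729 times, reserving one position per image patch embedding.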


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249  # noqa
class SiglipVisionEmbeddings(nn.Module):

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        Interpolate the pre-trained position encodings so that the model can
        be used on higher-resolution images. This adaptation is needed for
        SigLIP because, unlike other ViTs, it has no class embedding.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed
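
    # Illustrative only (assumed values): with patch_size=14 and a pre-trained
    # 27x27 position grid (729 positions), a 448x448 input produces a 32x32
    # patch grid, so the 27x27 grid of position embeddings is bicubically
    # resized to 32x32 before being added to the patch embeddings.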

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) "
                "must be divisible by the tensor model parallel size"
                f" ({tp_size}).")

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.embed_dim // self.total_num_heads
        if self.head_dim * self.total_num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads (got "
                f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")

        self.qkv_size = self.num_heads * self.head_dim
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.attn_fn = self._basic_attention_forward

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.split(
            [self.qkv_size] * 3, dim=-1)

        attn_output = self.attn_fn(
            q=query_states,
            k=key_states,
            v=value_states,
            batch_size=batch_size,
            q_len=q_len,
        )

        attn_output, _ = self.out_proj(attn_output)
        return attn_output

    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        k_v_seq_len = k.shape[-2]
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale

        if attn_weights.size() != (
                batch_size,
                self.num_heads,
                q_len,
                k_v_seq_len,
        ):
            raise ValueError(
                "Attention weights should be of size "
                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}")

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(q.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, v)

        if attn_output.size() != (
                batch_size,
                self.num_heads,
                q_len,
                self.head_dim,
        ):
            raise ValueError(
                "`attn_output` should be of size "
                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
#                       It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._flash_attention_forward

    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
                                 **kwargs):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the
                query, key, and value. (B, S, H, D)
        """

        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = flash_attn_func(
            q,
            k,
            v,
            dropout_p=self.dropout,
            causal=False,
        )

        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False
        self.attn_fn = self._sdpa_attention_forward

    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._xformers_attention_forward

    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = memory_efficient_attention(q,
                                                 k,
                                                 v,
                                                 p=0.0,
                                                 scale=self.scale)
        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES = {
    "eager": SiglipTPAttention,
    "flash_attention_2": SiglipFlashAttention2,
    "sdpa": SiglipSdpaAttention,
    "xformers": SiglipxFormersAttention,
}
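
# Illustrative only: the classes above are currently unused, but the mapping
# suggests how a backend would presumably be picked once the TP'ed attention
# path is enabled, e.g.:
#
#   attn_cls = SIGLIP_ATTENTION_CLASSES["sdpa"]
#   self_attn = attn_cls(config, quant_config=quant_config)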


class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require the hidden size to be a multiple of 64
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # TODO(ChristopherCho): use TP'ed Attention block
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement aphrodite version of MHA
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


class SiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)
        self.use_head = (True if not hasattr(config, "vision_use_head") else
                         config.vision_use_head)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        last_hidden_state = self.post_layernorm(encoder_outputs)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        #     pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
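

# Minimal usage sketch (illustrative only): assumes aphrodite's distributed
# state is already initialized so the parallel linear/embedding layers can be
# constructed, and uses arbitrary config values rather than a real checkpoint.
#
#   from transformers import SiglipVisionConfig
#
#   config = SiglipVisionConfig(image_size=384, patch_size=14)
#   model = SiglipVisionModel(config)
#   pixel_values = torch.zeros(1, 3, 384, 384)   # (batch, channels, H, W)
#   features = model(pixel_values)               # (1, num_patches, hidden_size)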