"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

import math
from array import array
from typing import Iterable, Optional, Tuple

import torch
from aphrodite_flash_attn import flash_attn_func
from PIL import Image
from torch import nn
from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention
from xformers.ops import memory_efficient_attention

from aphrodite.common.config import ModelConfig
from aphrodite.common.sequence import SequenceData
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.inputs import LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.multimodal.image import (cached_get_tokenizer,
                                        repeat_and_pad_image_tokens)
from aphrodite.quantization import QuantizationConfig


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    # Since interpolation is applied, the image size need not be divisible
    # assert image_size % patch_size == 0
    return image_size // patch_size


def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_siglip_patch_grid_length(image_size=image_size,
                                               patch_size=patch_size)
    return grid_length * grid_length


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_num_patches(image_size=hf_config.image_size,
                                  patch_size=hf_config.patch_size)


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_image_feature_size(hf_config)
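
# Example (illustrative, not part of the original file): with the default
# SiglipVisionConfig values image_size=224 and patch_size=16,
#   get_siglip_patch_grid_length(image_size=224, patch_size=16)  # -> 14
#   get_siglip_num_patches(image_size=224, patch_size=16)        # -> 196
# so get_max_siglip_image_tokens(hf_config) reserves 196 tokens per image.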


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}
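
# Example (illustrative): for the default 224-pixel config,
# dummy_image_for_siglip(hf_config, num_images=2) returns
# {"image": [<224x224 black RGB image>, <224x224 black RGB image>]}.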


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
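
# Example (illustrative): if the prompt token ids contain one occurrence of
# `image_token_id`, repeat_and_pad_image_tokens expands it into
# `image_feature_size` consecutive copies (196 for the default config), so the
# language model receives one placeholder token per image patch.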


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        Interpolates the pre-trained position encodings so that the model can
        be used on higher-resolution images. This is an adaptation for SigLIP,
        which, unlike other ViTs, has no class embedding.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size

        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed
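
    # Illustrative example: a checkpoint pre-trained at 224x224 with 16-pixel
    # patches stores 196 position embeddings (a 14x14 grid). For a 384x384
    # input (24x24 = 576 patches), the grid is reshaped to (1, 14, 14, dim),
    # permuted to (1, dim, 14, 14), bicubically resized to (1, dim, 24, 24),
    # and flattened back to (1, 576, dim) before being added to the patch
    # embeddings in forward() below.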

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) "
                "must be divisible by the tensor model parallel size"
                f" ({tp_size}).")

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.embed_dim // self.total_num_heads
        if self.head_dim * self.total_num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")
        self.qkv_size = self.num_heads * self.head_dim
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.attn_fn = self._basic_attention_forward

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.split(
            [self.qkv_size] * 3, dim=-1)

        attn_output = self.attn_fn(
            q=query_states,
            k=key_states,
            v=value_states,
            batch_size=batch_size,
            q_len=q_len,
        )

        attn_output, _ = self.out_proj(attn_output)
        return attn_output

    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        k_v_seq_len = k.shape[-2]
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale

        if attn_weights.size() != (
                batch_size,
                self.num_heads,
                q_len,
                k_v_seq_len,
        ):
            raise ValueError(
                "Attention weights should be of size "
                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}")

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(q.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, v)

        if attn_output.size() != (
                batch_size,
                self.num_heads,
                q_len,
                self.head_dim,
        ):
            raise ValueError(
                "`attn_output` should be of size "
                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        return attn_output
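
    # Shape summary (illustrative): (B, S, hidden) inputs are split per local
    # head into (B, H, S, D), the attention weights take shape (B, H, S, S),
    # and the weighted values are merged back to (B, S, embed_dim) before the
    # row-parallel output projection (as written, the final reshape assumes a
    # tensor-parallel size of 1).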


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
#                       It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._flash_attention_forward

    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
                                 **kwargs):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the
                     query, key, and value. (B, S, H, D)
        """

        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = flash_attn_func(
            q,
            k,
            v,
            dropout_p=self.dropout,
            causal=False,
        )

        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False
        self.attn_fn = self._sdpa_attention_forward

    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._xformers_attention_forward

    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = memory_efficient_attention(q,
                                                 k,
                                                 v,
                                                 p=0.0,
                                                 scale=self.scale)
        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES = {
    "eager": SiglipTPAttention,
    "flash_attention_2": SiglipFlashAttention2,
    "sdpa": SiglipSdpaAttention,
    "xformers": SiglipxFormersAttention,
}
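
# Example (illustrative): once the TP'ed attention path is enabled, a backend
# could be selected from this mapping, e.g.
#   attn_cls = SIGLIP_ATTENTION_CLASSES["sdpa"]
#   self_attn = attn_cls(config, quant_config=quant_config)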


class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require the hidden size to be a multiple of 64
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        # TODO(ChristopherCho): use TP'ed Attention block
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement aphrodite version of MHA
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]
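
    # Note (illustrative): the pooling head uses a single learned "probe"
    # token as the attention query over all patch tokens, so an input of shape
    # (B, num_patches, hidden_size) is reduced to one pooled vector of shape
    # (B, hidden_size) per image.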


class SiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)
        self.use_head = (True if not hasattr(config, "vision_use_head") else
                         config.vision_use_head)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        last_hidden_state = self.post_layernorm(encoder_outputs)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        #     pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
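

# Illustrative usage sketch (not part of the original file; assumes the
# Aphrodite distributed state required by the parallel linear layers has
# already been initialized):
#
#   config = SiglipVisionConfig()          # defaults: hidden_size=768,
#                                          # image_size=224, patch_size=16
#   model = SiglipVisionModel(config)
#   pixel_values = torch.randn(1, 3, 224, 224)
#   features = model(pixel_values)         # -> shape (1, 196, 768)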