1
0

na_vit.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803
import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (ModelOutput, is_flash_attn_2_available,
                                replace_return_docstrings)
  19. logger = logging.getLogger("aphrodite")
  20. # For Siglip: copied from
  21. # HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
  22. # Remove hints as there's little possibility to change these code.
  23. class SiglipVisionConfig(PretrainedConfig):
  24. model_type = "siglip_vision_model"
  25. def __init__(
  26. self,
  27. hidden_size=768,
  28. intermediate_size=3072,
  29. num_hidden_layers=12,
  30. num_attention_heads=12,
  31. num_channels=3,
  32. image_size=224,
  33. patch_size=16,
  34. hidden_act="gelu_pytorch_tanh",
  35. layer_norm_eps=1e-6,
  36. attention_dropout=0.0,
  37. **kwargs,
  38. ):
  39. super().__init__(**kwargs)
  40. self.hidden_size = hidden_size
  41. self.intermediate_size = intermediate_size
  42. self.num_hidden_layers = num_hidden_layers
  43. self.num_attention_heads = num_attention_heads
  44. self.num_channels = num_channels
  45. self.patch_size = patch_size
  46. self.image_size = image_size
  47. self.attention_dropout = attention_dropout
  48. self.layer_norm_eps = layer_norm_eps
  49. self.hidden_act = hidden_act
  50. @classmethod
  51. def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
  52. os.PathLike],
  53. **kwargs) -> "PretrainedConfig":
  54. cls._set_token_in_kwargs(kwargs)
  55. config_dict, kwargs = cls.get_config_dict(
  56. pretrained_model_name_or_path, **kwargs)
  57. # get the vision config dict if we are loading from SiglipConfig
  58. if config_dict.get("model_type") == "siglip":
  59. config_dict = config_dict["vision_config"]
  60. if "model_type" in config_dict and hasattr(
  61. cls,
  62. "model_type") and config_dict["model_type"] != cls.model_type:
  63. logger.warning(
  64. f"You are using a model of type {config_dict['model_type']} to"
  65. f" instantiate a model of type {cls.model_type}. "
  66. "This is not supported for all configurations of models and "
  67. "can yield errors.")
  68. return cls.from_dict(config_dict, **kwargs)
  69. _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
  70. SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
  71. "google/siglip-base-patch16-224",
  72. # See all SigLIP models at https://huggingface.co/models?filter=siglip
  73. ]
  74. if is_flash_attn_2_available():
  75. from flash_attn import flash_attn_func, flash_attn_varlen_func
  76. from flash_attn.bert_padding import pad_input # noqa
  77. from flash_attn.bert_padding import index_first_axis, unpad_input
  78. # Copied from transformers.models.llama.modeling_llama._get_unpad_data
  79. def _get_unpad_data(attention_mask):
  80. seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  81. indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  82. max_seqlen_in_batch = seqlens_in_batch.max().item()
  83. cu_seqlens = F.pad(
  84. torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  85. return (
  86. indices,
  87. cu_seqlens,
  88. max_seqlen_in_batch,
  89. )
  90. def _trunc_normal_(tensor, mean, std, a, b):
  91. def norm_cdf(x):
  92. # Computes standard normal cumulative distribution function
  93. return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
  94. if (mean < a - 2 * std) or (mean > b + 2 * std):
  95. warnings.warn(
  96. "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
  97. "The distribution of values may be incorrect.",
  98. stacklevel=2,
  99. )
  100. # Values are generated by using a truncated uniform distribution and
  101. # then using the inverse CDF for the normal distribution.
  102. # Get upper and lower cdf values
  103. l_ = norm_cdf((a - mean) / std)
  104. u = norm_cdf((b - mean) / std)
  105. # Uniformly fill tensor with values from [l, u], then translate to
  106. # [2l-1, 2u-1].
  107. tensor.uniform_(2 * l_ - 1, 2 * u - 1)
  108. # Use inverse cdf transform for normal distribution to get truncated
  109. # standard normal
  110. if tensor.dtype in [torch.float16, torch.bfloat16]:
  111. # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
  112. og_dtype = tensor.dtype
  113. tensor = tensor.to(torch.float32)
  114. tensor.erfinv_()
  115. tensor = tensor.to(og_dtype)
  116. else:
  117. tensor.erfinv_()
  118. # Transform to proper mean, std
  119. tensor.mul_(std * math.sqrt(2.0))
  120. tensor.add_(mean)
  121. # Clamp to ensure it's in the proper range
  122. if tensor.dtype == torch.float16:
  123. # The `clamp_` op is not (yet?) defined in float16+cpu
  124. tensor = tensor.to(torch.float32)
  125. tensor.clamp_(min=a, max=b)
  126. tensor = tensor.to(torch.float16)
  127. else:
  128. tensor.clamp_(min=a, max=b)
  129. def trunc_normal_tf_(tensor: torch.Tensor,
  130. mean: float = 0.0,
  131. std: float = 1.0,
  132. a: float = -2.0,
  133. b: float = 2.0) -> torch.Tensor:
  134. with torch.no_grad():
  135. _trunc_normal_(tensor, 0, 1.0, a, b)
  136. tensor.mul_(std).add_(mean)
  137. def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
  138. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  139. if mode == "fan_in":
  140. denom = fan_in
  141. elif mode == "fan_out":
  142. denom = fan_out
  143. elif mode == "fan_avg":
  144. denom = (fan_in + fan_out) / 2
  145. variance = scale / denom
  146. if distribution == "truncated_normal":
  147. # constant is stddev of standard normal truncated to (-2, 2)
  148. trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
  149. elif distribution == "normal":
  150. with torch.no_grad():
  151. tensor.normal_(std=math.sqrt(variance))
  152. elif distribution == "uniform":
  153. bound = math.sqrt(3 * variance)
  154. with torch.no_grad():
  155. tensor.uniform_(-bound, bound)
  156. else:
  157. raise ValueError(f"invalid distribution {distribution}")
  158. def lecun_normal_(tensor):
  159. variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
  160. def default_flax_embed_init(tensor):
  161. variance_scaling_(tensor, mode="fan_in", distribution="normal")
  162. class SiglipVisionModelOutput(ModelOutput):
  163. image_embeds: Optional[torch.FloatTensor] = None
  164. last_hidden_state: torch.FloatTensor = None
  165. hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  166. attentions: Optional[Tuple[torch.FloatTensor]] = None
  167. class SiglipVisionEmbeddings(nn.Module):
  168. def __init__(self, config: SiglipVisionConfig):
  169. super().__init__()
  170. self.config = config
  171. self.embed_dim = config.hidden_size
  172. self.image_size = config.image_size
  173. self.patch_size = config.patch_size
  174. self.patch_embedding = nn.Conv2d(
  175. in_channels=config.num_channels,
  176. out_channels=self.embed_dim,
  177. kernel_size=self.patch_size,
  178. stride=self.patch_size,
  179. padding="valid",
  180. )
  181. self.num_patches_per_side = self.image_size // self.patch_size
  182. self.num_patches = self.num_patches_per_side**2
  183. self.num_positions = self.num_patches
  184. self.position_embedding = nn.Embedding(self.num_positions,
  185. self.embed_dim)
  186. def forward(self,
  187. pixel_values: torch.FloatTensor,
  188. patch_attention_mask: torch.BoolTensor,
  189. tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
  190. batch_size = pixel_values.size(0)
  191. patch_embeds = self.patch_embedding(pixel_values)
  192. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  193. max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
  194. max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size,
  195. max_im_w // self.patch_size)
  196. boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
  197. 1 / self.num_patches_per_side)
  198. position_ids = torch.full(
  199. size=(
  200. batch_size,
  201. max_nb_patches_h * max_nb_patches_w,
  202. ),
  203. fill_value=0,
  204. )
  205. for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
  206. if tgt_sizes is not None:
  207. nb_patches_h = tgt_sizes[batch_idx][0]
  208. nb_patches_w = tgt_sizes[batch_idx][1]
  209. else:
  210. nb_patches_h = p_attn_mask[:, 0].sum()
  211. nb_patches_w = p_attn_mask[0].sum()
  212. fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
  213. fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
  214. bucket_coords_h = torch.bucketize(fractional_coords_h,
  215. boundaries,
  216. right=True)
  217. bucket_coords_w = torch.bucketize(fractional_coords_w,
  218. boundaries,
  219. right=True)
  220. pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
  221. bucket_coords_w).flatten()
  222. position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
  223. position_ids = position_ids.to(self.position_embedding.weight.device)
  224. embeddings = embeddings + self.position_embedding(position_ids)
  225. return embeddings
  226. class SiglipAttention(nn.Module):
  227. """Multi-headed attention from 'Attention Is All You Need' paper"""
  228. # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
  229. def __init__(self, config):
  230. super().__init__()
  231. self.config = config
  232. self.embed_dim = config.hidden_size
  233. self.num_heads = config.num_attention_heads
  234. self.head_dim = self.embed_dim // self.num_heads
  235. if self.head_dim * self.num_heads != self.embed_dim:
  236. raise ValueError(
  237. "embed_dim must be divisible by num_heads (got `embed_dim`: "
  238. f"{self.embed_dim} and `num_heads`:"
  239. f" {self.num_heads}).")
  240. self.scale = self.head_dim**-0.5
  241. self.dropout = config.attention_dropout
  242. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  243. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  244. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  245. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. attention_mask: Optional[torch.Tensor] = None,
  250. output_attentions: Optional[bool] = False,
  251. ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
  252. Optional[Tuple[torch.Tensor]]]:
  253. """Input shape: Batch x Time x Channel"""
  254. batch_size, q_len, _ = hidden_states.size()
  255. query_states = self.q_proj(hidden_states)
  256. key_states = self.k_proj(hidden_states)
  257. value_states = self.v_proj(hidden_states)
  258. query_states = query_states.view(batch_size, q_len, self.num_heads,
  259. self.head_dim).transpose(1, 2)
  260. key_states = key_states.view(batch_size, q_len, self.num_heads,
  261. self.head_dim).transpose(1, 2)
  262. value_states = value_states.view(batch_size, q_len, self.num_heads,
  263. self.head_dim).transpose(1, 2)
  264. k_v_seq_len = key_states.shape[-2]
  265. attn_weights = torch.matmul(query_states, key_states.transpose(
  266. 2, 3)) * self.scale
  267. if attn_weights.size() != (batch_size, self.num_heads, q_len,
  268. k_v_seq_len):
  269. raise ValueError(
  270. "Attention weights should be of size "
  271. f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
  272. f" {attn_weights.size()}")
  273. if attention_mask is not None:
  274. if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
  275. raise ValueError(
  276. "Attention mask should be of size "
  277. f"{(batch_size, 1, q_len, k_v_seq_len)}",
  278. f"but is {attention_mask.size()}")
  279. attn_weights = attn_weights + attention_mask
  280. # upcast attention to fp32
  281. attn_weights = nn.functional.softmax(attn_weights,
  282. dim=-1,
  283. dtype=torch.float32).to(
  284. query_states.dtype)
  285. attn_weights = nn.functional.dropout(attn_weights,
  286. p=self.dropout,
  287. training=self.training)
  288. attn_output = torch.matmul(attn_weights, value_states)
  289. if attn_output.size() != (batch_size, self.num_heads, q_len,
  290. self.head_dim):
  291. raise ValueError(
  292. "`attn_output` should be of size "
  293. f"{(batch_size, self.num_heads, q_len, self.head_dim)}, "
  294. "but is"
  295. f" {attn_output.size()}")
  296. attn_output = attn_output.transpose(1, 2).contiguous()
  297. attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
  298. attn_output = self.out_proj(attn_output)
  299. return attn_output, attn_weights
  300. class SiglipFlashAttention2(SiglipAttention):
  301. def __init__(self, *args, **kwargs):
  302. super().__init__(*args, **kwargs)
  303. self.is_causal = False # Hack to make sure we don't use a causal mask
  304. def forward(
  305. self,
  306. hidden_states: torch.Tensor,
  307. attention_mask: Optional[torch.LongTensor] = None,
  308. position_ids: Optional[torch.LongTensor] = None,
  309. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  310. output_attentions: bool = False,
  311. use_cache: bool = False,
  312. **kwargs,
  313. ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
  314. Optional[Tuple[torch.Tensor]]]:
  315. output_attentions = False
  316. bsz, q_len, _ = hidden_states.size()
  317. query_states = self.q_proj(hidden_states)
  318. key_states = self.k_proj(hidden_states)
  319. value_states = self.v_proj(hidden_states)
  320. query_states = query_states.view(bsz, q_len, self.num_heads,
  321. self.head_dim).transpose(1, 2)
  322. key_states = key_states.view(bsz, q_len, self.num_heads,
  323. self.head_dim).transpose(1, 2)
  324. value_states = value_states.view(bsz, q_len, self.num_heads,
  325. self.head_dim).transpose(1, 2)
  326. kv_seq_len = key_states.shape[-2]
  327. if past_key_value is not None:
  328. kv_seq_len += past_key_value.get_usable_length(
  329. kv_seq_len, self.layer_idx)
  330. query_states = query_states.transpose(1, 2)
  331. key_states = key_states.transpose(1, 2)
  332. value_states = value_states.transpose(1, 2)
  333. dropout_rate = self.dropout if self.training else 0.0
  334. input_dtype = query_states.dtype
  335. if input_dtype == torch.float32:
  336. if torch.is_autocast_enabled():
  337. target_dtype = torch.get_autocast_gpu_dtype()
  338. # Handle the case where the model is quantized
  339. elif hasattr(self.config, "_pre_quantization_dtype"):
  340. target_dtype = self.config._pre_quantization_dtype
  341. else:
  342. target_dtype = self.q_proj.weight.dtype
  343. logger.warning(
  344. "The input hidden states seems to be "
  345. "silently casted in float32, "
  346. "this might be related to the fact "
  347. "you have upcasted embedding or layer norm layers in float32. "
  348. "We will cast back the input in"
  349. f" {target_dtype}.")
  350. query_states = query_states.to(target_dtype)
  351. key_states = key_states.to(target_dtype)
  352. value_states = value_states.to(target_dtype)
  353. attn_output = self._flash_attention_forward(query_states,
  354. key_states,
  355. value_states,
  356. attention_mask,
  357. q_len,
  358. dropout=dropout_rate)
  359. attn_output = attn_output.reshape(bsz, q_len,
  360. self.embed_dim).contiguous()
  361. attn_output = self.out_proj(attn_output)
  362. if not output_attentions:
  363. attn_weights = None
  364. return attn_output, attn_weights
  365. def _flash_attention_forward(self,
  366. query_states,
  367. key_states,
  368. value_states,
  369. attention_mask,
  370. query_length,
  371. dropout=0.0,
  372. softmax_scale=None):
  373. causal = self.is_causal and query_length != 1
  374. # Contains at least one padding token in the sequence
  375. if attention_mask is not None:
  376. batch_size = query_states.shape[0]
  377. (query_states, key_states, value_states, indices_q, cu_seq_lens,
  378. max_seq_lens) = self._upad_input(query_states, key_states,
  379. value_states, attention_mask,
  380. query_length)
  381. cu_seqlens_q, cu_seqlens_k = cu_seq_lens
  382. max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
  383. attn_output_unpad = flash_attn_varlen_func(
  384. query_states,
  385. key_states,
  386. value_states,
  387. cu_seqlens_q=cu_seqlens_q,
  388. cu_seqlens_k=cu_seqlens_k,
  389. max_seqlen_q=max_seqlen_in_batch_q,
  390. max_seqlen_k=max_seqlen_in_batch_k,
  391. dropout_p=dropout,
  392. softmax_scale=softmax_scale,
  393. causal=causal,
  394. )
  395. attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
  396. query_length)
  397. else:
  398. attn_output = flash_attn_func(query_states,
  399. key_states,
  400. value_states,
  401. dropout,
  402. softmax_scale=softmax_scale,
  403. causal=causal)
  404. return attn_output
  405. def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
  406. query_length):
  407. indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
  408. attention_mask)
  409. batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
  410. key_layer = index_first_axis(
  411. key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
  412. head_dim), indices_k)
  413. value_layer = index_first_axis(
  414. value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
  415. head_dim), indices_k)
  416. if query_length == kv_seq_len:
  417. query_layer = index_first_axis(
  418. query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
  419. head_dim), indices_k)
  420. cu_seqlens_q = cu_seqlens_k
  421. max_seqlen_in_batch_q = max_seqlen_in_batch_k
  422. indices_q = indices_k
  423. elif query_length == 1:
  424. max_seqlen_in_batch_q = 1
  425. cu_seqlens_q = torch.arange(
  426. batch_size + 1, dtype=torch.int32, device=query_layer.device
  427. ) # There is a memcpy here, that is very bad.
  428. indices_q = cu_seqlens_q[:-1]
  429. query_layer = query_layer.squeeze(1)
  430. else:
  431. # The -q_len: slice assumes left padding.
  432. attention_mask = attention_mask[:, -query_length:]
  433. (query_layer, indices_q, cu_seqlens_q,
  434. max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask)
  435. return (
  436. query_layer,
  437. key_layer,
  438. value_layer,
  439. indices_q,
  440. (cu_seqlens_q, cu_seqlens_k),
  441. (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
  442. )
  443. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
  444. class SiglipMLP(nn.Module):
  445. def __init__(self, config):
  446. super().__init__()
  447. self.config = config
  448. self.activation_fn = ACT2FN[config.hidden_act]
  449. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  450. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  451. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  452. hidden_states = self.fc1(hidden_states)
  453. hidden_states = self.activation_fn(hidden_states)
  454. hidden_states = self.fc2(hidden_states)
  455. return hidden_states
  456. # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer
  457. # with CLIP->Siglip
  458. class SiglipEncoderLayer(nn.Module):
  459. def __init__(self, config: SiglipVisionConfig):
  460. super().__init__()
  461. self.embed_dim = config.hidden_size
  462. self._use_flash_attention_2 = (
  463. config._attn_implementation == "flash_attention_2")
  464. self.self_attn = (SiglipAttention(config)
  465. if not self._use_flash_attention_2 else
  466. SiglipFlashAttention2(config))
  467. self.layer_norm1 = nn.LayerNorm(self.embed_dim,
  468. eps=config.layer_norm_eps)
  469. self.mlp = SiglipMLP(config)
  470. self.layer_norm2 = nn.LayerNorm(self.embed_dim,
  471. eps=config.layer_norm_eps)
  472. def forward(
  473. self,
  474. hidden_states: torch.Tensor,
  475. attention_mask: torch.Tensor,
  476. output_attentions: Optional[bool] = False,
  477. ) -> Tuple[torch.FloatTensor]:
  478. residual = hidden_states
  479. hidden_states = self.layer_norm1(hidden_states)
  480. hidden_states, attn_weights = self.self_attn(
  481. hidden_states=hidden_states,
  482. attention_mask=attention_mask,
  483. output_attentions=output_attentions,
  484. )
  485. hidden_states = residual + hidden_states
  486. residual = hidden_states
  487. hidden_states = self.layer_norm2(hidden_states)
  488. hidden_states = self.mlp(hidden_states)
  489. hidden_states = residual + hidden_states
  490. outputs = (hidden_states, )
  491. if output_attentions:
  492. outputs += (attn_weights, )
  493. return outputs
  494. class SiglipPreTrainedModel(PreTrainedModel):
  495. config_class = SiglipVisionConfig
  496. base_model_prefix = "siglip"
  497. supports_gradient_checkpointing = True
  498. def _init_weights(self, module):
  499. """Initialize the weights"""
  500. if isinstance(module, SiglipVisionEmbeddings):
  501. width = self.config.hidden_size
  502. nn.init.normal_(module.position_embedding.weight,
  503. std=1 / np.sqrt(width))
  504. elif isinstance(module, nn.Embedding):
  505. default_flax_embed_init(module.weight)
  506. elif isinstance(module, SiglipAttention):
  507. nn.init.normal_(module.q_proj.weight)
  508. nn.init.normal_(module.k_proj.weight)
  509. nn.init.normal_(module.v_proj.weight)
  510. nn.init.normal_(module.out_proj.weight)
  511. nn.init.zeros_(module.q_proj.bias)
  512. nn.init.zeros_(module.k_proj.bias)
  513. nn.init.zeros_(module.v_proj.bias)
  514. nn.init.zeros_(module.out_proj.bias)
  515. elif isinstance(module, SiglipMLP):
  516. nn.init.normal_(module.fc1.weight)
  517. nn.init.normal_(module.fc2.weight)
  518. nn.init.normal_(module.fc1.bias, std=1e-6)
  519. nn.init.normal_(module.fc2.bias, std=1e-6)
  520. elif isinstance(module, (nn.Linear, nn.Conv2d)):
  521. lecun_normal_(module.weight)
  522. if module.bias is not None:
  523. nn.init.zeros_(module.bias)
  524. elif isinstance(module, nn.LayerNorm):
  525. module.bias.data.zero_()
  526. module.weight.data.fill_(1.0)
  527. # Copied from transformers.models.clip.modeling_clip.CLIPEncoder
  528. # with CLIP->Siglip
  529. class SiglipEncoder(nn.Module):
  530. def __init__(self, config: SiglipVisionConfig):
  531. super().__init__()
  532. self.config = config
  533. self.layers = nn.ModuleList([
  534. SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)
  535. ])
  536. self.gradient_checkpointing = False
  537. # Ignore copy
  538. def forward(
  539. self,
  540. inputs_embeds,
  541. attention_mask: Optional[torch.Tensor] = None,
  542. output_attentions: Optional[bool] = None,
  543. output_hidden_states: Optional[bool] = None,
  544. return_dict: Optional[bool] = None,
  545. ) -> Union[Tuple, BaseModelOutput]:
  546. output_attentions = output_attentions if output_attentions is not None \
  547. else self.config.output_attentions
  548. output_hidden_states = (output_hidden_states
  549. if output_hidden_states is not None else
  550. self.config.output_hidden_states)
  551. return_dict = return_dict if return_dict is not None \
  552. else self.config.use_return_dict
  553. encoder_states = () if output_hidden_states else None
  554. all_attentions = () if output_attentions else None
  555. hidden_states = inputs_embeds
  556. for encoder_layer in self.layers:
  557. if output_hidden_states:
  558. encoder_states = encoder_states + (hidden_states, )
  559. if self.gradient_checkpointing and self.training:
  560. layer_outputs = self._gradient_checkpointing_func(
  561. encoder_layer.__call__,
  562. hidden_states,
  563. attention_mask,
  564. output_attentions,
  565. )
  566. else:
  567. layer_outputs = encoder_layer(
  568. hidden_states,
  569. attention_mask,
  570. output_attentions=output_attentions,
  571. )
  572. hidden_states = layer_outputs[0]
  573. if output_attentions:
  574. all_attentions = all_attentions + (layer_outputs[1], )
  575. if output_hidden_states:
  576. encoder_states = encoder_states + (hidden_states, )
  577. if not return_dict:
  578. return tuple(
  579. v for v in [hidden_states, encoder_states, all_attentions]
  580. if v is not None)
  581. return BaseModelOutput(last_hidden_state=hidden_states,
  582. hidden_states=encoder_states,
  583. attentions=all_attentions)
  584. class SiglipVisionTransformer(SiglipPreTrainedModel):
  585. config_class = SiglipVisionConfig
  586. main_input_name = "pixel_values"
  587. _supports_flash_attn_2 = True
  588. def __init__(self, config: SiglipVisionConfig):
  589. super().__init__(config)
  590. self.config = config
  591. embed_dim = config.hidden_size
  592. self.embeddings = SiglipVisionEmbeddings(config)
  593. self.encoder = SiglipEncoder(config)
  594. self.post_layernorm = nn.LayerNorm(embed_dim,
  595. eps=config.layer_norm_eps)
  596. self._use_flash_attention_2 = (
  597. config._attn_implementation == "flash_attention_2")
  598. # Initialize weights and apply final processing
  599. self.post_init()
  600. def get_input_embeddings(self) -> nn.Module:
  601. return self.embeddings.patch_embedding
  602. @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
  603. config_class=SiglipVisionConfig)
  604. def forward(
  605. self,
  606. pixel_values,
  607. patch_attention_mask: Optional[torch.BoolTensor] = None,
  608. tgt_sizes: Optional[torch.IntTensor] = None,
  609. output_attentions: Optional[bool] = None,
  610. output_hidden_states: Optional[bool] = None,
  611. return_dict: Optional[bool] = None,
  612. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  613. r"""
  614. Returns:
  615. """
  616. output_attentions = output_attentions if output_attentions is not None \
  617. else self.config.output_attentions
  618. output_hidden_states = (output_hidden_states
  619. if output_hidden_states is not None else
  620. self.config.output_hidden_states)
  621. return_dict = return_dict if return_dict is not None \
  622. else self.config.use_return_dict
  623. batch_size = pixel_values.size(0)
  624. if patch_attention_mask is None:
  625. patch_attention_mask = torch.ones(
  626. size=(
  627. batch_size,
  628. pixel_values.size(2) // self.config.patch_size,
  629. pixel_values.size(3) // self.config.patch_size,
  630. ),
  631. dtype=torch.bool,
  632. device=pixel_values.device,
  633. )
  634. hidden_states = self.embeddings(
  635. pixel_values=pixel_values,
  636. patch_attention_mask=patch_attention_mask,
  637. tgt_sizes=tgt_sizes)
  638. patch_attention_mask = patch_attention_mask.view(batch_size, -1)
  639. # The call to `_upad_input` in `_flash_attention_forward` is expensive
  640. # So when the `patch_attention_mask` is full of 1s
  641. # (i.e. attending to the whole sequence),
  642. # avoiding passing the attention_mask,
  643. # which is equivalent to attending to the full sequence
  644. if not torch.any(~patch_attention_mask):
  645. attention_mask = None
  646. else:
  647. attention_mask = (_prepare_4d_attention_mask(
  648. patch_attention_mask, hidden_states.dtype)
  649. if not self._use_flash_attention_2 else
  650. patch_attention_mask)
  651. encoder_outputs = self.encoder(
  652. inputs_embeds=hidden_states,
  653. attention_mask=attention_mask,
  654. output_attentions=output_attentions,
  655. output_hidden_states=output_hidden_states,
  656. return_dict=return_dict,
  657. )
  658. last_hidden_state = encoder_outputs[0]
  659. last_hidden_state = self.post_layernorm(last_hidden_state)
  660. if not return_dict:
  661. return (last_hidden_state, None) + encoder_outputs[1:]
  662. return BaseModelOutputWithPooling(
  663. last_hidden_state=last_hidden_state,
  664. pooler_output=None,
  665. hidden_states=encoder_outputs.hidden_states,
  666. attentions=encoder_outputs.attentions,
  667. )