1
0

na_vit.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (ModelOutput, is_flash_attn_2_available,
                                replace_return_docstrings)
  19. logger = logging.getLogger("aphrodite")
  20. # For Siglip: copied from
  21. # HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
  22. # Remove hints as there's little possibility to change these code.
  23. class SiglipVisionConfig(PretrainedConfig):
  24. model_type = "siglip_vision_model"
  25. def __init__(
  26. self,
  27. hidden_size=768,
  28. intermediate_size=3072,
  29. num_hidden_layers=12,
  30. num_attention_heads=12,
  31. num_channels=3,
  32. image_size=224,
  33. patch_size=16,
  34. hidden_act="gelu_pytorch_tanh",
  35. layer_norm_eps=1e-6,
  36. attention_dropout=0.0,
  37. **kwargs,
  38. ):
  39. super().__init__(**kwargs)
  40. self.hidden_size = hidden_size
  41. self.intermediate_size = intermediate_size
  42. self.num_hidden_layers = num_hidden_layers
  43. self.num_attention_heads = num_attention_heads
  44. self.num_channels = num_channels
  45. self.patch_size = patch_size
  46. self.image_size = image_size
  47. self.attention_dropout = attention_dropout
  48. self.layer_norm_eps = layer_norm_eps
  49. self.hidden_act = hidden_act
  50. @classmethod
  51. def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
  52. os.PathLike],
  53. **kwargs) -> "PretrainedConfig":
  54. cls._set_token_in_kwargs(kwargs)
  55. config_dict, kwargs = cls.get_config_dict(
  56. pretrained_model_name_or_path, **kwargs)
  57. # get the vision config dict if we are loading from SiglipConfig
  58. if config_dict.get("model_type") == "siglip":
  59. config_dict = config_dict["vision_config"]
  60. if "model_type" in config_dict and hasattr(
  61. cls,
  62. "model_type") and config_dict["model_type"] != cls.model_type:
  63. logger.warning(
  64. "You are using a model of type %s to "
  65. "instantiate a model of type %s. "
  66. "This is not supported for all configurations"
  67. "of models and can yield errors.", config_dict['model_type'],
  68. cls.model_type)
  69. return cls.from_dict(config_dict, **kwargs)
  70. _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
  71. SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
  72. "google/siglip-base-patch16-224",
  73. # See all SigLIP models at https://huggingface.co/models?filter=siglip
  74. ]
  75. if is_flash_attn_2_available():
  76. from flash_attn import flash_attn_func, flash_attn_varlen_func
  77. from flash_attn.bert_padding import pad_input # noqa
  78. from flash_attn.bert_padding import index_first_axis, unpad_input
  79. # Copied from transformers.models.llama.modeling_llama._get_unpad_data
  80. def _get_unpad_data(attention_mask):
  81. seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  82. indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  83. max_seqlen_in_batch = seqlens_in_batch.max().item()
  84. cu_seqlens = F.pad(
  85. torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  86. return (
  87. indices,
  88. cu_seqlens,
  89. max_seqlen_in_batch,
  90. )
  91. def _trunc_normal_(tensor, mean, std, a, b):
  92. def norm_cdf(x):
  93. # Computes standard normal cumulative distribution function
  94. return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
  95. if (mean < a - 2 * std) or (mean > b + 2 * std):
  96. warnings.warn(
  97. "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
  98. "The distribution of values may be incorrect.",
  99. stacklevel=2,
  100. )
  101. # Values are generated by using a truncated uniform distribution and
  102. # then using the inverse CDF for the normal distribution.
  103. # Get upper and lower cdf values
  104. l_ = norm_cdf((a - mean) / std)
  105. u = norm_cdf((b - mean) / std)
  106. # Uniformly fill tensor with values from [l, u], then translate to
  107. # [2l-1, 2u-1].
  108. tensor.uniform_(2 * l_ - 1, 2 * u - 1)
  109. # Use inverse cdf transform for normal distribution to get truncated
  110. # standard normal
  111. if tensor.dtype in [torch.float16, torch.bfloat16]:
  112. # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
  113. og_dtype = tensor.dtype
  114. tensor = tensor.to(torch.float32)
  115. tensor.erfinv_()
  116. tensor = tensor.to(og_dtype)
  117. else:
  118. tensor.erfinv_()
  119. # Transform to proper mean, std
  120. tensor.mul_(std * math.sqrt(2.0))
  121. tensor.add_(mean)
  122. # Clamp to ensure it's in the proper range
  123. if tensor.dtype == torch.float16:
  124. # The `clamp_` op is not (yet?) defined in float16+cpu
  125. tensor = tensor.to(torch.float32)
  126. tensor.clamp_(min=a, max=b)
  127. tensor = tensor.to(torch.float16)
  128. else:
  129. tensor.clamp_(min=a, max=b)
  130. def trunc_normal_tf_(tensor: torch.Tensor,
  131. mean: float = 0.0,
  132. std: float = 1.0,
  133. a: float = -2.0,
  134. b: float = 2.0) -> torch.Tensor:
  135. with torch.no_grad():
  136. _trunc_normal_(tensor, 0, 1.0, a, b)
  137. tensor.mul_(std).add_(mean)
  138. def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
  139. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  140. if mode == "fan_in":
  141. denom = fan_in
  142. elif mode == "fan_out":
  143. denom = fan_out
  144. elif mode == "fan_avg":
  145. denom = (fan_in + fan_out) / 2
  146. variance = scale / denom
  147. if distribution == "truncated_normal":
  148. # constant is stddev of standard normal truncated to (-2, 2)
  149. trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
  150. elif distribution == "normal":
  151. with torch.no_grad():
  152. tensor.normal_(std=math.sqrt(variance))
  153. elif distribution == "uniform":
  154. bound = math.sqrt(3 * variance)
  155. with torch.no_grad():
  156. tensor.uniform_(-bound, bound)
  157. else:
  158. raise ValueError(f"invalid distribution {distribution}")
  159. def lecun_normal_(tensor):
  160. variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
  161. def default_flax_embed_init(tensor):
  162. variance_scaling_(tensor, mode="fan_in", distribution="normal")
  163. class SiglipVisionModelOutput(ModelOutput):
  164. image_embeds: Optional[torch.FloatTensor] = None
  165. last_hidden_state: torch.FloatTensor = None
  166. hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  167. attentions: Optional[Tuple[torch.FloatTensor]] = None
  168. class SiglipVisionEmbeddings(nn.Module):
  169. def __init__(self, config: SiglipVisionConfig):
  170. super().__init__()
  171. self.config = config
  172. self.embed_dim = config.hidden_size
  173. self.image_size = config.image_size
  174. self.patch_size = config.patch_size
  175. self.patch_embedding = nn.Conv2d(
  176. in_channels=config.num_channels,
  177. out_channels=self.embed_dim,
  178. kernel_size=self.patch_size,
  179. stride=self.patch_size,
  180. padding="valid",
  181. )
  182. self.num_patches_per_side = self.image_size // self.patch_size
  183. self.num_patches = self.num_patches_per_side**2
  184. self.num_positions = self.num_patches
  185. self.position_embedding = nn.Embedding(self.num_positions,
  186. self.embed_dim)
  187. def forward(self,
  188. pixel_values: torch.FloatTensor,
  189. patch_attention_mask: torch.BoolTensor,
  190. tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
  191. batch_size = pixel_values.size(0)
  192. patch_embeds = self.patch_embedding(pixel_values)
  193. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  194. max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
  195. max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size,
  196. max_im_w // self.patch_size)
  197. boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
  198. 1 / self.num_patches_per_side)
  199. position_ids = torch.full(
  200. size=(
  201. batch_size,
  202. max_nb_patches_h * max_nb_patches_w,
  203. ),
  204. fill_value=0,
  205. )
  206. for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
  207. if tgt_sizes is not None:
  208. nb_patches_h = tgt_sizes[batch_idx][0]
  209. nb_patches_w = tgt_sizes[batch_idx][1]
  210. else:
  211. nb_patches_h = p_attn_mask[:, 0].sum()
  212. nb_patches_w = p_attn_mask[0].sum()
  213. fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
  214. fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
  215. bucket_coords_h = torch.bucketize(fractional_coords_h,
  216. boundaries,
  217. right=True)
  218. bucket_coords_w = torch.bucketize(fractional_coords_w,
  219. boundaries,
  220. right=True)
  221. pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
  222. bucket_coords_w).flatten()
  223. position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
  224. position_ids = position_ids.to(self.position_embedding.weight.device)
  225. embeddings = embeddings + self.position_embedding(position_ids)
  226. return embeddings
  227. class SiglipAttention(nn.Module):
  228. """Multi-headed attention from 'Attention Is All You Need' paper"""
  229. # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
  230. def __init__(self, config):
  231. super().__init__()
  232. self.config = config
  233. self.embed_dim = config.hidden_size
  234. self.num_heads = config.num_attention_heads
  235. self.head_dim = self.embed_dim // self.num_heads
  236. if self.head_dim * self.num_heads != self.embed_dim:
  237. raise ValueError(
  238. "embed_dim must be divisible by num_heads (got `embed_dim`: "
  239. f"{self.embed_dim} and `num_heads`:"
  240. f" {self.num_heads}).")
  241. self.scale = self.head_dim**-0.5
  242. self.dropout = config.attention_dropout
  243. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  244. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  245. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  246. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  247. def forward(
  248. self,
  249. hidden_states: torch.Tensor,
  250. attention_mask: Optional[torch.Tensor] = None,
  251. output_attentions: Optional[bool] = False,
  252. ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
  253. Optional[Tuple[torch.Tensor]]]:
  254. """Input shape: Batch x Time x Channel"""
  255. batch_size, q_len, _ = hidden_states.size()
  256. query_states = self.q_proj(hidden_states)
  257. key_states = self.k_proj(hidden_states)
  258. value_states = self.v_proj(hidden_states)
  259. query_states = query_states.view(batch_size, q_len, self.num_heads,
  260. self.head_dim).transpose(1, 2)
  261. key_states = key_states.view(batch_size, q_len, self.num_heads,
  262. self.head_dim).transpose(1, 2)
  263. value_states = value_states.view(batch_size, q_len, self.num_heads,
  264. self.head_dim).transpose(1, 2)
  265. k_v_seq_len = key_states.shape[-2]
  266. attn_weights = torch.matmul(query_states, key_states.transpose(
  267. 2, 3)) * self.scale
  268. if attn_weights.size() != (batch_size, self.num_heads, q_len,
  269. k_v_seq_len):
  270. raise ValueError(
  271. "Attention weights should be of size "
  272. f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
  273. f" {attn_weights.size()}")
  274. if attention_mask is not None:
  275. if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
  276. raise ValueError(
  277. "Attention mask should be of size "
  278. f"{(batch_size, 1, q_len, k_v_seq_len)}",
  279. f"but is {attention_mask.size()}")
  280. attn_weights = attn_weights + attention_mask
  281. # upcast attention to fp32
  282. attn_weights = nn.functional.softmax(attn_weights,
  283. dim=-1,
  284. dtype=torch.float32).to(
  285. query_states.dtype)
  286. attn_weights = nn.functional.dropout(attn_weights,
  287. p=self.dropout,
  288. training=self.training)
  289. attn_output = torch.matmul(attn_weights, value_states)
  290. if attn_output.size() != (batch_size, self.num_heads, q_len,
  291. self.head_dim):
  292. raise ValueError(
  293. "`attn_output` should be of size "
  294. f"{(batch_size, self.num_heads, q_len, self.head_dim)}, "
  295. "but is"
  296. f" {attn_output.size()}")
  297. attn_output = attn_output.transpose(1, 2).contiguous()
  298. attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
  299. attn_output = self.out_proj(attn_output)
  300. return attn_output, attn_weights
  301. class SiglipFlashAttention2(SiglipAttention):
  302. def __init__(self, *args, **kwargs):
  303. super().__init__(*args, **kwargs)
  304. self.is_causal = False # Hack to make sure we don't use a causal mask
  305. def forward(
  306. self,
  307. hidden_states: torch.Tensor,
  308. attention_mask: Optional[torch.LongTensor] = None,
  309. position_ids: Optional[torch.LongTensor] = None,
  310. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  311. output_attentions: bool = False,
  312. use_cache: bool = False,
  313. **kwargs,
  314. ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
  315. Optional[Tuple[torch.Tensor]]]:
  316. output_attentions = False
  317. bsz, q_len, _ = hidden_states.size()
  318. query_states = self.q_proj(hidden_states)
  319. key_states = self.k_proj(hidden_states)
  320. value_states = self.v_proj(hidden_states)
  321. query_states = query_states.view(bsz, q_len, self.num_heads,
  322. self.head_dim).transpose(1, 2)
  323. key_states = key_states.view(bsz, q_len, self.num_heads,
  324. self.head_dim).transpose(1, 2)
  325. value_states = value_states.view(bsz, q_len, self.num_heads,
  326. self.head_dim).transpose(1, 2)
  327. kv_seq_len = key_states.shape[-2]
  328. if past_key_value is not None:
  329. kv_seq_len += past_key_value.get_usable_length(
  330. kv_seq_len, self.layer_idx)
  331. query_states = query_states.transpose(1, 2)
  332. key_states = key_states.transpose(1, 2)
  333. value_states = value_states.transpose(1, 2)
  334. dropout_rate = self.dropout if self.training else 0.0
  335. input_dtype = query_states.dtype
  336. if input_dtype == torch.float32:
  337. if torch.is_autocast_enabled():
  338. target_dtype = torch.get_autocast_gpu_dtype()
  339. # Handle the case where the model is quantized
  340. elif hasattr(self.config, "_pre_quantization_dtype"):
  341. target_dtype = self.config._pre_quantization_dtype
  342. else:
  343. target_dtype = self.q_proj.weight.dtype
  344. logger.warning(
  345. "The input hidden states seems to be "
  346. "silently casted in float32, "
  347. "this might be related to the fact "
  348. "you have upcasted embedding or layer norm layers in float32. "
  349. "We will cast back the input in"
  350. " %s.", target_dtype)
  351. query_states = query_states.to(target_dtype)
  352. key_states = key_states.to(target_dtype)
  353. value_states = value_states.to(target_dtype)
  354. attn_output = self._flash_attention_forward(query_states,
  355. key_states,
  356. value_states,
  357. attention_mask,
  358. q_len,
  359. dropout=dropout_rate)
  360. attn_output = attn_output.reshape(bsz, q_len,
  361. self.embed_dim).contiguous()
  362. attn_output = self.out_proj(attn_output)
  363. if not output_attentions:
  364. attn_weights = None
  365. return attn_output, attn_weights
  366. def _flash_attention_forward(self,
  367. query_states,
  368. key_states,
  369. value_states,
  370. attention_mask,
  371. query_length,
  372. dropout=0.0,
  373. softmax_scale=None):
  374. causal = self.is_causal and query_length != 1
  375. # Contains at least one padding token in the sequence
  376. if attention_mask is not None:
  377. batch_size = query_states.shape[0]
  378. (query_states, key_states, value_states, indices_q, cu_seq_lens,
  379. max_seq_lens) = self._upad_input(query_states, key_states,
  380. value_states, attention_mask,
  381. query_length)
  382. cu_seqlens_q, cu_seqlens_k = cu_seq_lens
  383. max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
  384. attn_output_unpad = flash_attn_varlen_func(
  385. query_states,
  386. key_states,
  387. value_states,
  388. cu_seqlens_q=cu_seqlens_q,
  389. cu_seqlens_k=cu_seqlens_k,
  390. max_seqlen_q=max_seqlen_in_batch_q,
  391. max_seqlen_k=max_seqlen_in_batch_k,
  392. dropout_p=dropout,
  393. softmax_scale=softmax_scale,
  394. causal=causal,
  395. )
  396. attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
  397. query_length)
  398. else:
  399. attn_output = flash_attn_func(query_states,
  400. key_states,
  401. value_states,
  402. dropout,
  403. softmax_scale=softmax_scale,
  404. causal=causal)
  405. return attn_output
  406. def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
  407. query_length):
  408. indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
  409. attention_mask)
  410. batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
  411. key_layer = index_first_axis(
  412. key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
  413. head_dim), indices_k)
  414. value_layer = index_first_axis(
  415. value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
  416. head_dim), indices_k)
  417. if query_length == kv_seq_len:
  418. query_layer = index_first_axis(
  419. query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
  420. head_dim), indices_k)
  421. cu_seqlens_q = cu_seqlens_k
  422. max_seqlen_in_batch_q = max_seqlen_in_batch_k
  423. indices_q = indices_k
  424. elif query_length == 1:
  425. max_seqlen_in_batch_q = 1
  426. cu_seqlens_q = torch.arange(
  427. batch_size + 1, dtype=torch.int32, device=query_layer.device
  428. ) # There is a memcpy here, that is very bad.
  429. indices_q = cu_seqlens_q[:-1]
  430. query_layer = query_layer.squeeze(1)
  431. else:
  432. # The -q_len: slice assumes left padding.
  433. attention_mask = attention_mask[:, -query_length:]
  434. (query_layer, indices_q, cu_seqlens_q,
  435. max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask)
  436. return (
  437. query_layer,
  438. key_layer,
  439. value_layer,
  440. indices_q,
  441. (cu_seqlens_q, cu_seqlens_k),
  442. (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
  443. )
  444. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
  445. class SiglipMLP(nn.Module):
  446. def __init__(self, config):
  447. super().__init__()
  448. self.config = config
  449. self.activation_fn = ACT2FN[config.hidden_act]
  450. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  451. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  452. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  453. hidden_states = self.fc1(hidden_states)
  454. hidden_states = self.activation_fn(hidden_states)
  455. hidden_states = self.fc2(hidden_states)
  456. return hidden_states
  457. # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer
  458. # with CLIP->Siglip
  459. class SiglipEncoderLayer(nn.Module):
  460. def __init__(self, config: SiglipVisionConfig):
  461. super().__init__()
  462. self.embed_dim = config.hidden_size
  463. self._use_flash_attention_2 = (
  464. config._attn_implementation == "flash_attention_2")
  465. self.self_attn = (SiglipAttention(config)
  466. if not self._use_flash_attention_2 else
  467. SiglipFlashAttention2(config))
  468. self.layer_norm1 = nn.LayerNorm(self.embed_dim,
  469. eps=config.layer_norm_eps)
  470. self.mlp = SiglipMLP(config)
  471. self.layer_norm2 = nn.LayerNorm(self.embed_dim,
  472. eps=config.layer_norm_eps)
  473. def forward(
  474. self,
  475. hidden_states: torch.Tensor,
  476. attention_mask: torch.Tensor,
  477. output_attentions: Optional[bool] = False,
  478. ) -> Tuple[torch.FloatTensor]:
  479. residual = hidden_states
  480. hidden_states = self.layer_norm1(hidden_states)
  481. hidden_states, attn_weights = self.self_attn(
  482. hidden_states=hidden_states,
  483. attention_mask=attention_mask,
  484. output_attentions=output_attentions,
  485. )
  486. hidden_states = residual + hidden_states
  487. residual = hidden_states
  488. hidden_states = self.layer_norm2(hidden_states)
  489. hidden_states = self.mlp(hidden_states)
  490. hidden_states = residual + hidden_states
  491. outputs = (hidden_states, )
  492. if output_attentions:
  493. outputs += (attn_weights, )
  494. return outputs
  495. class SiglipPreTrainedModel(PreTrainedModel):
  496. config_class = SiglipVisionConfig
  497. base_model_prefix = "siglip"
  498. supports_gradient_checkpointing = True
  499. def _init_weights(self, module):
  500. """Initialize the weights"""
  501. if isinstance(module, SiglipVisionEmbeddings):
  502. width = self.config.hidden_size
  503. nn.init.normal_(module.position_embedding.weight,
  504. std=1 / np.sqrt(width))
  505. elif isinstance(module, nn.Embedding):
  506. default_flax_embed_init(module.weight)
  507. elif isinstance(module, SiglipAttention):
  508. nn.init.normal_(module.q_proj.weight)
  509. nn.init.normal_(module.k_proj.weight)
  510. nn.init.normal_(module.v_proj.weight)
  511. nn.init.normal_(module.out_proj.weight)
  512. nn.init.zeros_(module.q_proj.bias)
  513. nn.init.zeros_(module.k_proj.bias)
  514. nn.init.zeros_(module.v_proj.bias)
  515. nn.init.zeros_(module.out_proj.bias)
  516. elif isinstance(module, SiglipMLP):
  517. nn.init.normal_(module.fc1.weight)
  518. nn.init.normal_(module.fc2.weight)
  519. nn.init.normal_(module.fc1.bias, std=1e-6)
  520. nn.init.normal_(module.fc2.bias, std=1e-6)
  521. elif isinstance(module, (nn.Linear, nn.Conv2d)):
  522. lecun_normal_(module.weight)
  523. if module.bias is not None:
  524. nn.init.zeros_(module.bias)
  525. elif isinstance(module, nn.LayerNorm):
  526. module.bias.data.zero_()
  527. module.weight.data.fill_(1.0)
  528. # Copied from transformers.models.clip.modeling_clip.CLIPEncoder
  529. # with CLIP->Siglip
  530. class SiglipEncoder(nn.Module):
  531. def __init__(self, config: SiglipVisionConfig):
  532. super().__init__()
  533. self.config = config
  534. self.layers = nn.ModuleList([
  535. SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)
  536. ])
  537. self.gradient_checkpointing = False
  538. # Ignore copy
  539. def forward(
  540. self,
  541. inputs_embeds,
  542. attention_mask: Optional[torch.Tensor] = None,
  543. output_attentions: Optional[bool] = None,
  544. output_hidden_states: Optional[bool] = None,
  545. return_dict: Optional[bool] = None,
  546. ) -> Union[Tuple, BaseModelOutput]:
  547. output_attentions = output_attentions if output_attentions is not None \
  548. else self.config.output_attentions
  549. output_hidden_states = (output_hidden_states
  550. if output_hidden_states is not None else
  551. self.config.output_hidden_states)
  552. return_dict = return_dict if return_dict is not None \
  553. else self.config.use_return_dict
  554. encoder_states = () if output_hidden_states else None
  555. all_attentions = () if output_attentions else None
  556. hidden_states = inputs_embeds
  557. for encoder_layer in self.layers:
  558. if output_hidden_states:
  559. encoder_states = encoder_states + (hidden_states, )
  560. if self.gradient_checkpointing and self.training:
  561. layer_outputs = self._gradient_checkpointing_func(
  562. encoder_layer.__call__,
  563. hidden_states,
  564. attention_mask,
  565. output_attentions,
  566. )
  567. else:
  568. layer_outputs = encoder_layer(
  569. hidden_states,
  570. attention_mask,
  571. output_attentions=output_attentions,
  572. )
  573. hidden_states = layer_outputs[0]
  574. if output_attentions:
  575. all_attentions = all_attentions + (layer_outputs[1], )
  576. if output_hidden_states:
  577. encoder_states = encoder_states + (hidden_states, )
  578. if not return_dict:
  579. return tuple(
  580. v for v in [hidden_states, encoder_states, all_attentions]
  581. if v is not None)
  582. return BaseModelOutput(last_hidden_state=hidden_states,
  583. hidden_states=encoder_states,
  584. attentions=all_attentions)
  585. class SiglipVisionTransformer(SiglipPreTrainedModel):
  586. config_class = SiglipVisionConfig
  587. main_input_name = "pixel_values"
  588. _supports_flash_attn_2 = True
  589. def __init__(self, config: SiglipVisionConfig):
  590. super().__init__(config)
  591. self.config = config
  592. embed_dim = config.hidden_size
  593. self.embeddings = SiglipVisionEmbeddings(config)
  594. self.encoder = SiglipEncoder(config)
  595. self.post_layernorm = nn.LayerNorm(embed_dim,
  596. eps=config.layer_norm_eps)
  597. self._use_flash_attention_2 = (
  598. config._attn_implementation == "flash_attention_2")
  599. # Initialize weights and apply final processing
  600. self.post_init()
  601. def get_input_embeddings(self) -> nn.Module:
  602. return self.embeddings.patch_embedding
  603. @replace_return_docstrings(output_type=BaseModelOutputWithPooling,
  604. config_class=SiglipVisionConfig)
  605. def forward(
  606. self,
  607. pixel_values,
  608. patch_attention_mask: Optional[torch.BoolTensor] = None,
  609. tgt_sizes: Optional[torch.IntTensor] = None,
  610. output_attentions: Optional[bool] = None,
  611. output_hidden_states: Optional[bool] = None,
  612. return_dict: Optional[bool] = None,
  613. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  614. r"""
  615. Returns:
  616. """
  617. output_attentions = output_attentions if output_attentions is not None \
  618. else self.config.output_attentions
  619. output_hidden_states = (output_hidden_states
  620. if output_hidden_states is not None else
  621. self.config.output_hidden_states)
  622. return_dict = return_dict if return_dict is not None \
  623. else self.config.use_return_dict
  624. batch_size = pixel_values.size(0)
  625. if patch_attention_mask is None:
  626. patch_attention_mask = torch.ones(
  627. size=(
  628. batch_size,
  629. pixel_values.size(2) // self.config.patch_size,
  630. pixel_values.size(3) // self.config.patch_size,
  631. ),
  632. dtype=torch.bool,
  633. device=pixel_values.device,
  634. )
  635. hidden_states = self.embeddings(
  636. pixel_values=pixel_values,
  637. patch_attention_mask=patch_attention_mask,
  638. tgt_sizes=tgt_sizes)
  639. patch_attention_mask = patch_attention_mask.view(batch_size, -1)
  640. # The call to `_upad_input` in `_flash_attention_forward` is expensive
  641. # So when the `patch_attention_mask` is full of 1s
  642. # (i.e. attending to the whole sequence),
  643. # avoiding passing the attention_mask,
  644. # which is equivalent to attending to the full sequence
  645. if not torch.any(~patch_attention_mask):
  646. attention_mask = None
  647. else:
  648. attention_mask = (_prepare_4d_attention_mask(
  649. patch_attention_mask, hidden_states.dtype)
  650. if not self._use_flash_attention_2 else
  651. patch_attention_mask)
  652. encoder_outputs = self.encoder(
  653. inputs_embeds=hidden_states,
  654. attention_mask=attention_mask,
  655. output_attentions=output_attentions,
  656. output_hidden_states=output_hidden_states,
  657. return_dict=return_dict,
  658. )
  659. last_hidden_state = encoder_outputs[0]
  660. last_hidden_state = self.post_layernorm(last_hidden_state)
  661. if not return_dict:
  662. return (last_hidden_state, None) + encoder_outputs[1:]
  663. return BaseModelOutputWithPooling(
  664. last_hidden_state=last_hidden_state,
  665. pooler_output=None,
  666. hidden_states=encoder_outputs.hidden_states,
  667. attentions=encoder_outputs.attentions,
  668. )