  1. """Implementation of SiglipVisionModel intended to be only used
  2. within a vision language model."""
  3. import math
  4. from typing import Iterable, List, Optional, Tuple, Union
  5. import numpy as np
  6. import torch
  7. from PIL import Image
  8. from torch import nn
  9. from transformers import SiglipVisionConfig
  10. from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
  11. from aphrodite.common.config import ModelConfig
  12. from aphrodite.common.sequence import SequenceData
  13. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  14. from aphrodite.inputs import LLMInputs
  15. from aphrodite.modeling.layers.activation import get_act_fn
  16. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  17. QKVParallelLinear,
  18. RowParallelLinear)
  19. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  20. VocabParallelEmbedding)
  21. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  22. from aphrodite.multimodal.utils import (cached_get_tokenizer,
  23. repeat_and_pad_placeholder_tokens)
  24. from aphrodite.quantization import QuantizationConfig
  25. try:
  26. from xformers import ops as xops
  27. USE_XFORMERS_OPS = True
  28. except ImportError:
  29. USE_XFORMERS_OPS = False
  30. def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  31. # Since interpolation is applied, the image size need not be divisible
  32. # assert image_size % patch_size == 0
  33. return image_size // patch_size
  34. def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
  35. grid_length = get_siglip_patch_grid_length(image_size=image_size,
  36. patch_size=patch_size)
  37. return grid_length * grid_length
  38. def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
  39. return get_siglip_num_patches(image_size=hf_config.image_size,
  40. patch_size=hf_config.patch_size)
  41. def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
  42. return get_siglip_image_feature_size(hf_config)
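

# Example for reference (values are illustrative, not tied to any particular
# checkpoint): with image_size=224 and patch_size=16, the patch grid is
# 224 // 16 = 14 per side, so an image contributes 14 * 14 = 196 feature
# tokens:
#
#     cfg = SiglipVisionConfig(image_size=224, patch_size=16)
#     get_siglip_patch_grid_length(image_size=cfg.image_size,
#                                  patch_size=cfg.patch_size)  # -> 14
#     get_max_siglip_image_tokens(cfg)                         # -> 196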


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    return SequenceData.from_token_counts(
        (image_token_id, image_feature_size * num_images),
        (0, seq_len - image_feature_size * num_images),
    )


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_video_for_siglip(
    hf_config: SiglipVisionConfig,
    num_frames: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    pil_frame = dummy_image_for_siglip(
        hf_config,
        num_images=1,
        image_width_override=image_width_override,
        image_height_override=image_height_override)
    np_frame = np.array(pil_frame["image"])
    mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
    mm_data = {"video": mm_data_per_video}
    return mm_data


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            num_images, image_feature_size, hidden_size = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
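

# Sketch of what the processor does (token ids below are hypothetical): given
# prompt token ids [1, 32000, 2] where 32000 is the image placeholder and an
# image feature size of 196, the placeholder is expanded so the returned
# prompt_token_ids are roughly
#
#     [1] + [32000] * 196 + [2]
#
# which reserves one position per image patch embedding in the LLM prompt.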


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        Interpolate the pre-trained position encodings so the model can be
        used on higher-resolution images. This is an adaptation for SigLIP,
        which, unlike other ViTs, has no class embedding.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size

        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )

        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [batch, embed_dim, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings
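

# Shape walkthrough for the embeddings above (illustrative sizes): with
# image_size=224, patch_size=16 and hidden_size=768, a (1, 3, 336, 336) input
# yields a 21 x 21 patch grid, i.e. patch embeddings of shape (1, 441, 768).
# interpolate_pos_encoding then resamples the pre-trained 14 x 14 position
# grid to 21 x 21 via bicubic interpolation and flattens it back to
# (1, 441, 768) before it is added to the patch embeddings.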


class SiglipParallelAttention(nn.Module):
    """Tensor-parallel multi-head attention; used only when xformers is
    available and the head count divides evenly across TP ranks."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        query_states = query_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(batch_size, q_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(batch_size, q_len, -1)
        attn_output, _ = self.out_proj(out)

        # The second element mirrors the (attn_output, attn_weights) return
        # convention of the HF attention modules; weights are not computed.
        return attn_output, None
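

# For illustration (hypothetical sizes): with hidden_size=768,
# num_attention_heads=12 and a tensor parallel size of 2, each rank owns
# 12 / 2 = 6 heads of dimension 768 / 12 = 64. qkv_proj then produces a
# (batch, seq, 3 * 6 * 64) tensor per rank, which is chunked into Q, K and V
# and reshaped to (batch, seq, 6, 64) each before calling the xformers
# memory-efficient attention kernel.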


class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require both the hidden size and the
        # intermediate size to be multiples of 64.
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.self_attn = SiglipParallelAttention(config,
                                                     quant_config=quant_config)
        else:
            self.self_attn = SiglipSdpaAttention(config)

        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)
        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement aphrodite version of MHA
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


class SiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {config.num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )
        elif len(self.encoder.layers) == config.num_hidden_layers:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            # post_layernorm is unused when we extract intermediate features
            # In this case, we can skip it to conserve memory
            self.post_layernorm = None

        self.use_head = (True if not hasattr(config, "vision_use_head") else
                         config.vision_use_head)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        if self.post_layernorm is None:
            return encoder_outputs

        last_hidden_state = self.post_layernorm(encoder_outputs)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        #     pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue

                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
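

# Usage sketch (not executed here; names like `hf_weights` and `pixel_values`
# are placeholders supplied by the surrounding VLM implementation):
#
#     vision_tower = SiglipVisionModel(hf_config.vision_config, quant_config)
#     vision_tower.load_weights(hf_weights)  # iterable of (name, tensor)
#     features = vision_tower(pixel_values)  # (num_images, num_patches, hidden)
#
# The returned features are typically projected by the VLM and scattered into
# the placeholder positions reserved by input_processor_for_siglip.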