siglip.py

  1. """Implementation of SiglipVisionModel intended to be only used
  2. within a vision language model."""
  3. import math
  4. from typing import Iterable, List, Optional, Tuple, Union
  5. import torch
  6. from PIL import Image
  7. from torch import nn
  8. from transformers import SiglipVisionConfig
  9. from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
  10. from aphrodite.common.config import ModelConfig
  11. from aphrodite.common.sequence import SequenceData
  12. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  13. from aphrodite.inputs import LLMInputs
  14. from aphrodite.modeling.layers.activation import get_act_fn
  15. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  16. QKVParallelLinear,
  17. RowParallelLinear)
  18. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  19. VocabParallelEmbedding)
  20. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  21. from aphrodite.multimodal.utils import (cached_get_tokenizer,
  22. repeat_and_pad_placeholder_tokens)
  23. from aphrodite.quantization import QuantizationConfig
  24. try:
  25. from xformers import ops as xops
  26. USE_XFORMERS_OPS = True
  27. except ImportError:
  28. USE_XFORMERS_OPS = False


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    # Since interpolation is applied, the image size need not be divisible
    # assert image_size % patch_size == 0
    return image_size // patch_size


def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_siglip_patch_grid_length(image_size=image_size,
                                               patch_size=patch_size)
    return grid_length * grid_length


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_num_patches(image_size=hf_config.image_size,
                                  patch_size=hf_config.patch_size)


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_image_feature_size(hf_config)
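

# Illustrative sketch of the token-count arithmetic above (the config values
# are assumptions for the example, not something this module pins down): a
# checkpoint with image_size=384 and patch_size=14 has a 384 // 14 = 27 patch
# grid per side, so each image occupies 27 * 27 = 729 placeholder tokens.
#
#   cfg = SiglipVisionConfig(image_size=384, patch_size=14)
#   get_siglip_patch_grid_length(image_size=cfg.image_size,
#                                patch_size=cfg.patch_size)  # -> 27
#   get_max_siglip_image_tokens(cfg)                         # -> 729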


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    return SequenceData.from_token_counts(
        (image_token_id, image_feature_size * num_images),
        (0, seq_len - image_feature_size * num_images),
    )


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            num_images, image_feature_size, hidden_size = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
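

# Sketch of what the processor above does to a prompt (token layout is made up
# for illustration): the single image placeholder is repeated so that the text
# stream reserves one position per visual embedding.
#
#   before: [bos, "What is this?", <image>]
#   after:  [bos, "What is this?", <image> * image_feature_size]
#
# The multi-modal data itself is passed through untouched; only the token ids
# (and the prompt string, if present) are rewritten.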


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249  # noqa
class SiglipVisionEmbeddings(nn.Module):

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        This method is adapted for SigLIP (which, unlike other ViTs, has no
        class embedding) and lets the model interpolate the pre-trained
        position encodings so that it can be used on higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size

        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings
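

# Rough numerical sketch of interpolate_pos_encoding (the values are
# assumptions for illustration): with a pretrained 27x27 position grid
# (num_positions = 729) and a 448x448 input at patch_size=14, the target grid
# is 32x32, so the 729 embeddings are reshaped to (1, 27, 27, dim), resized
# bicubically to (1, 32, 32, dim), and flattened back to (1, 1024, dim) before
# being added to the patch embeddings.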


class SiglipParallelAttention(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        query_states = query_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(batch_size, q_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(batch_size, q_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None
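

# Shape sketch for the parallel attention above (the head count and tp_size
# are example values): with num_heads=16 and tp_size=2, each rank owns 8
# heads, so after the fused QKV projection the chunked q/k/v tensors are
# viewed as [batch, seq, 8, head_dim] and handed to
# xops.memory_efficient_attention_forward, which takes [B, M, H, K] inputs;
# the row-parallel out_proj then all-reduces the partial outputs back to the
# full embed_dim.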


class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require the hidden size to be a multiple of 64
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states
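

# Note on the `quantizable` guard above: if either hidden_size or
# intermediate_size is not a multiple of 64, the MLP silently falls back to
# unquantized fc1/fc2 rather than raising, so a quantized checkpoint may still
# run these layers in the model's native dtype.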


class SiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.self_attn = SiglipParallelAttention(
                config, quant_config=quant_config)
        else:
            self.self_attn = SiglipSdpaAttention(config)

        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement aphrodite version of MHA
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]
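

# Pooling sketch: the learned `probe` acts as a single query token, so for an
# input of shape [batch, num_patches, hidden] the MultiheadAttention call
# returns [batch, 1, hidden]; after the residual MLP, `hidden_state[:, 0]`
# collapses that singleton length dimension to yield one pooled vector per
# image.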


class SiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {config.num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )
        elif len(self.encoder.layers) == config.num_hidden_layers:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            # post_layernorm is unused when we extract intermediate features
            # In this case, we can skip it to conserve memory
            self.post_layernorm = None

        self.use_head = (True if not hasattr(config, "vision_use_head") else
                         config.vision_use_head)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        if self.post_layernorm is None:
            return encoder_outputs

        last_hidden_state = self.post_layernorm(encoder_outputs)
        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        #     pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    @property
    def _require_post_layernorm(self) -> bool:
        return self.vision_model.post_layernorm is not None

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if ("vision_model.post_layernorm" in name
                    and not self._require_post_layernorm):
                continue

            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue

                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
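

# End-to-end usage sketch (hypothetical host code, not part of this module):
# a parent vision-language model typically builds the tower with a truncated
# layer stack and feeds it pre-processed pixel values.
#
#   vision_tower = SiglipVisionModel(
#       config=hf_config.vision_config,
#       quant_config=quant_config,
#       num_hidden_layers_override=hf_config.vision_config.num_hidden_layers - 1,
#   )
#   vision_tower.load_weights(hf_weights)  # assumed iterable of (name, tensor)
#   image_features = vision_tower(pixel_values)  # [batch, num_patches, hidden]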