  1. """Implementation of SiglipVisionModel intended to be only used
  2. within a vision language model."""
  3. import math
  4. from array import array
  5. from typing import Iterable, List, Optional, Tuple, Union
  6. import torch
  7. from PIL import Image
  8. from torch import nn
  9. from transformers import SiglipVisionConfig
  10. from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
  11. from aphrodite.common.config import ModelConfig
  12. from aphrodite.common.sequence import SequenceData
  13. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  14. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  15. from aphrodite.inputs import LLMInputs
  16. from aphrodite.modeling.layers.activation import get_act_fn
  17. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  18. QKVParallelLinear,
  19. RowParallelLinear)
  20. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  21. VocabParallelEmbedding)
  22. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  23. from aphrodite.multimodal.utils import (cached_get_tokenizer,
  24. repeat_and_pad_placeholder_tokens)
  25. from aphrodite.quantization import QuantizationConfig
  26. try:
  27. from xformers import ops as xops
  28. USE_XFORMERS_OPS = True
  29. except ImportError:
  30. USE_XFORMERS_OPS = False


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    # Since interpolation is applied, the image size need not be divisible
    # by the patch size.
    # assert image_size % patch_size == 0
    return image_size // patch_size


def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_siglip_patch_grid_length(image_size=image_size,
                                               patch_size=patch_size)
    return grid_length * grid_length


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_num_patches(image_size=hf_config.image_size,
                                  patch_size=hf_config.patch_size)


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_image_feature_size(hf_config)
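
# For example, with the settings used by google/siglip-so400m-patch14-384
# (image_size=384, patch_size=14), the patch grid length is 384 // 14 = 27,
# so each image contributes 27 * 27 = 729 feature tokens.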


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            num_images, image_feature_size, hidden_size = image_data.shape
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )
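
# For instance, if a prompt contains a single image placeholder token and the
# image contributes 729 features (a 384x384 image with 14x14 patches), the
# processor above expands that one placeholder into 729 consecutive copies so
# the prompt length matches the number of vision embeddings injected later.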


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        This method is adapted for SigLIP (which, unlike other ViTs, has no
        class embedding) and allows the model to interpolate the pre-trained
        position encodings so that they can be used at higher image
        resolutions.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings


class SiglipParallelAttention(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
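        # Each tensor-parallel rank keeps num_heads / tp_size heads (e.g. 16
        # heads at tp_size=2 means 8 heads per rank); qkv_proj and out_proj
        # are sharded accordingly.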

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        query_states = query_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(batch_size, q_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(batch_size, q_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output, None


class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require the hidden size to be a multiple of 64
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size

        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.self_attn = SiglipParallelAttention(config,
                                                     quant_config=quant_config)
        else:
            self.self_attn = SiglipSdpaAttention(config)

        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
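        # The probe is a single learned query token: in forward() it attends
        # over all patch embeddings and yields one pooled vector per image.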
        # TODO(ChristopherCho): Implement aphrodite version of MHA
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


class SiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {config.num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )
        elif len(self.encoder.layers) == config.num_hidden_layers:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            # post_layernorm is unused when we extract intermediate features
            # In this case, we can skip it to conserve memory
            self.post_layernorm = None

        self.use_head = (True if not hasattr(config, "vision_use_head") else
                         config.vision_use_head)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        if self.post_layernorm is None:
            return encoder_outputs

        last_hidden_state = self.post_layernorm(encoder_outputs)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        #     pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    @property
    def _require_post_layernorm(self) -> bool:
        return self.vision_model.post_layernorm is not None

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []
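        # Checkpoints store separate q_proj/k_proj/v_proj tensors; when the
        # sharded SiglipParallelAttention is in use, the mapping above loads
        # them into the fused qkv_proj parameter. Otherwise the HF
        # SiglipSdpaAttention keeps the original names and the mapping is
        # empty.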
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is optional in SiglipVisionModel
            if ("vision_model.post_layernorm" in name
                    and not self._require_post_layernorm):
                continue

            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue

                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
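

# Illustrative usage sketch (comment only, not executed here). It assumes a
# single-GPU setup in which aphrodite's distributed state has already been
# initialized (tensor-parallel size 1 is sufficient); the exact init helpers
# depend on the calling engine code.
#
#     config = SiglipVisionConfig()  # defaults: 224x224 images, 16x16 patches
#     model = SiglipVisionModel(config).eval()
#     pixel_values = torch.randn(1, config.num_channels,
#                                config.image_size, config.image_size)
#     with torch.no_grad():
#         features = model(pixel_values)  # [1, 196, config.hidden_size]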