  1. """Implementation of SiglipVisionModel intended to be only used
  2. within a vision language model."""
  3. import math
  4. from array import array
  5. from typing import Iterable, List, Optional, Tuple, Union
  6. import torch
  7. from PIL import Image
  8. from torch import nn
  9. from transformers import SiglipVisionConfig
  10. from xformers import ops as xops
  11. from aphrodite.common.config import ModelConfig
  12. from aphrodite.common.sequence import SequenceData
  13. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  14. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  15. from aphrodite.inputs import LLMInputs
  16. from aphrodite.modeling.layers.activation import get_act_fn
  17. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  18. QKVParallelLinear,
  19. RowParallelLinear)
  20. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  21. VocabParallelEmbedding)
  22. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  23. from aphrodite.multimodal.utils import (cached_get_tokenizer,
  24. repeat_and_pad_placeholder_tokens)
  25. from aphrodite.quantization import QuantizationConfig
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    # Since interpolation is applied, the image size need not be divisible
    # by the patch size.
    # assert image_size % patch_size == 0
    return image_size // patch_size


def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_siglip_patch_grid_length(image_size=image_size,
                                               patch_size=patch_size)
    return grid_length * grid_length


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_num_patches(image_size=hf_config.image_size,
                                  patch_size=hf_config.patch_size)


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_image_feature_size(hf_config)


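# Illustrative sketch (not part of the original module): how the helpers above
# combine for a concrete configuration. The 384/14 values below are example
# assumptions picked for the sketch, not read from any checkpoint.
def _example_siglip_token_count() -> int:
    cfg = SiglipVisionConfig(image_size=384, patch_size=14)
    # 384 // 14 == 27 patches per side, so 27 * 27 == 729 image tokens.
    return get_max_siglip_image_tokens(cfg)

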
def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size * num_images
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )


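# Illustrative sketch (not part of the original module): the token-level
# effect of the expansion done by input_processor_for_siglip. Plain lists
# stand in for the tokenizer-aware repeat_and_pad_placeholder_tokens helper;
# all ids below are arbitrary example values.
def _example_placeholder_expansion() -> List[int]:
    image_token_id, image_feature_size = 9, 4
    prompt_token_ids = [1, 9, 7, 8]  # one image placeholder at position 1
    expanded: List[int] = []
    for tok in prompt_token_ids:
        if tok == image_token_id:
            # The single placeholder is repeated once per image feature.
            expanded.extend([image_token_id] * image_feature_size)
        else:
            expanded.append(tok)
    return expanded  # [1, 9, 9, 9, 9, 7, 8]

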
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand(
                (1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """
        This method is an adaptation for SigLIP (which, unlike other ViTs,
        has no class embedding) that allows the model to interpolate the
        pre-trained position encodings so it can be used on
        higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)),
            dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            mode="bicubic",
            align_corners=False,
        )
        if (int(height) != patch_pos_embed.shape[-2]
                or int(width) != patch_pos_embed.shape[-1]):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                self.position_ids)
        return embeddings


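# Illustrative sketch (not part of the original module): the resize performed
# by interpolate_pos_encoding above, written with plain torch so it can run
# without initializing tensor parallelism. The 27x27 -> 32x32 grid sizes and
# the embedding dim are example values.
def _example_interpolate_positions() -> torch.Tensor:
    dim = 8
    pos = torch.randn(1, 27 * 27, dim)  # pretrained table for a 27x27 grid
    grid = pos.reshape(1, 27, 27, dim).permute(0, 3, 1, 2)
    resized = nn.functional.interpolate(grid,
                                        size=(32, 32),
                                        mode="bicubic",
                                        align_corners=False)
    # Back to a (1, num_patches, dim) table for a 32x32 patch grid.
    return resized.permute(0, 2, 3, 1).view(1, -1, dim)

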
class SiglipAttention(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads (got "
                             f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                             f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

        query_states = query_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(batch_size, q_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(batch_size, q_len, -1)
        attn_output, _ = self.out_proj(out)

        return attn_output


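# Illustrative sketch (not part of the original module): the tensor layout the
# attention above hands to xformers. Plain torch stand-ins are used so the
# shapes can be checked without a fused QKV layer or a GPU; the sizes are
# arbitrary example values.
def _example_attention_layout() -> torch.Size:
    batch_size, q_len, num_heads, head_dim = 2, 729, 4, 16
    qkv = torch.randn(batch_size, q_len, 3 * num_heads * head_dim)
    q, k, v = qkv.chunk(3, dim=-1)
    # Each of q/k/v becomes (batch, seq, heads, head_dim), the layout expected
    # by xops.memory_efficient_attention_forward.
    q = q.view(batch_size, q_len, num_heads, head_dim)
    return q.shape  # torch.Size([2, 729, 4, 16])

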
class SiglipMLP(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require the hidden size to be a multiple of 64
        quantizable = (config.hidden_size % 64 == 0
                       and config.intermediate_size % 64 == 0)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


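# Illustrative sketch (not part of the original module): the divisibility
# check above that decides whether the MLP projections receive a quant_config.
# The 1152/4304 sizes are example values in the range of a SigLIP config;
# 4304 is not a multiple of 64, so quantization would be skipped here.
def _example_mlp_quantizable() -> bool:
    hidden_size, intermediate_size = 1152, 4304
    return hidden_size % 64 == 0 and intermediate_size % 64 == 0  # False

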
class SiglipEncoderLayer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = SiglipAttention(config, quant_config=quant_config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement aphrodite version of MHA
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]


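# Illustrative sketch (not part of the original module): probe-based pooling
# as computed by the head above, using only plain torch modules (the parallel
# MLP and layer norm are omitted) so it runs without distributed setup. All
# sizes are example values.
def _example_probe_pooling() -> torch.Tensor:
    hidden_size, num_heads = 64, 4
    attn = torch.nn.MultiheadAttention(hidden_size, num_heads,
                                       batch_first=True)
    probe = torch.randn(1, 1, hidden_size)
    hidden_state = torch.randn(2, 729, hidden_size)  # (batch, patches, dim)
    # A single learned query attends over all patch embeddings, collapsing
    # the sequence to one pooled vector per image.
    pooled = attn(probe.repeat(2, 1, 1), hidden_state, hidden_state)[0]
    return pooled[:, 0]  # shape (2, hidden_size)

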
class SiglipVisionTransformer(nn.Module):

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)
        self.use_head = (True if not hasattr(config, "vision_use_head") else
                         config.vision_use_head)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        last_hidden_state = self.post_layernorm(encoder_outputs)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        #     pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        super().__init__()
        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if "vision_model.encoder.layers." in name:
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)


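# Illustrative sketch (not part of the original module): the layer-index
# parsing used by load_weights above to drop checkpoint layers beyond
# num_hidden_layers_override. The weight name is an example string.
def _example_layer_index() -> int:
    name = "vision_model.encoder.layers.26.mlp.fc1.weight"
    # Index 3 of the dotted path is the encoder layer number.
    return int(name.split(".")[3])  # 26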