  1. """Minimal implementation of BlipVisionModel intended to be only used
  2. within a vision language model."""
  3. from array import array
  4. from typing import Optional, Union
  5. import torch
  6. import torch.nn as nn
  7. from PIL import Image
  8. from transformers import Blip2VisionConfig, BlipVisionConfig
  9. from transformers.models.blip.modeling_blip import BlipAttention
  10. from aphrodite.common.config import ModelConfig
  11. from aphrodite.common.sequence import SequenceData
  12. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  13. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  14. from aphrodite.inputs import LLMInputs
  15. from aphrodite.modeling.layers.activation import get_act_fn
  16. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  17. QKVParallelLinear,
  18. RowParallelLinear)
  19. from aphrodite.multimodal.utils import (cached_get_tokenizer,
  20. repeat_and_pad_placeholder_tokens)
  21. from aphrodite.quantization import QuantizationConfig
  22. try:
  23. from xformers import ops as xops
  24. USE_XFORMERS_OPS = True
  25. except ImportError:
  26. USE_XFORMERS_OPS = False


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    assert image_size % patch_size == 0
    return image_size // patch_size


def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_blip_patch_grid_length(image_size=image_size,
                                             patch_size=patch_size)
    return grid_length * grid_length


def get_blip_image_feature_size(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
    return get_blip_num_patches(image_size=hf_config.image_size,
                                patch_size=hf_config.patch_size)


def get_max_blip_image_tokens(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
    return get_blip_image_feature_size(hf_config)
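
# Illustrative sketch (not part of the original module): with image_size=384
# and patch_size=16, which are assumed example values rather than something
# this module guarantees, the helpers above give a 24x24 patch grid, i.e.
# 576 image feature tokens per image:
#
#     cfg = BlipVisionConfig(image_size=384, patch_size=16)
#     assert get_blip_patch_grid_length(image_size=384, patch_size=16) == 24
#     assert get_blip_image_feature_size(cfg) == 576  # 24 * 24 patches
#     assert get_max_blip_image_tokens(cfg) == 576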


def dummy_seq_data_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size)
    return SequenceData(token_ids)
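
# Sketch of the dummy sequence layout produced above (illustrative; the
# image_token_id value and feature size below are hypothetical): the first
# image_feature_size positions hold the image placeholder token and the rest
# of the sequence is padded with token id 0.
#
#     seq = dummy_seq_data_for_blip(cfg, 1024, image_token_id=50265)
#     # token ids: [50265] * 576 + [0] * (1024 - 576)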


def dummy_image_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_blip(
    model_config: ModelConfig,
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
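
# Illustrative sketch of the placeholder expansion above (assumed behaviour of
# repeat_and_pad_placeholder_tokens; all token ids are hypothetical): a prompt
# containing a single image placeholder token is expanded so that the
# placeholder occupies image_feature_size positions, one per vision token.
#
#     # before: [bos, 50265, 100, 101]           (one <image> placeholder)
#     # after:  [bos] + [50265] * 576 + [100, 101]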


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):

    def __init__(self, config: BlipVisionConfig):
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = get_blip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        position_embeds = self.position_embedding.to(target_dtype)
        embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]

        return embeddings
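
# Shape sketch for the embedding layer above (illustrative, assuming
# image_size=384 and patch_size=16; substitute the values from your config):
#
#     pixel_values:                                    [batch, 3, 384, 384]
#     patch_embeds after Conv2d + flatten/transpose:   [batch, 576, embed_dim]
#     embeddings with the prepended class embedding:   [batch, 577, embed_dim]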


class BlipParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )
        self.projection = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.projection(out)

        # Return None for the attention weights to mirror the (output, weights)
        # tuple interface of the eager BlipAttention fallback.
        return attn_output, None
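
# Tensor-parallel sketch (illustrative, assuming 16 attention heads and
# tp_size=2): QKVParallelLinear shards the heads across ranks, so each rank
# sees num_heads_per_partition = 8 and a per-rank qkv width of
# 3 * 8 * head_dim; RowParallelLinear then all-reduces the projection output
# back to the full embed_dim.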


class BlipMLP(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class BlipEncoderLayer(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # Fall back to the eager HF attention if xformers is unavailable or the
        # head count cannot be divided across the tensor-parallel ranks.
        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.self_attn = BlipParallelAttention(config,
                                                   quant_config=quant_config)
        else:
            # Blip doesn't have SDPA attention implemented in transformers
            # use eager attention instead for cpu backend
            self.self_attn = BlipAttention(config)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`BlipEncoderLayer`].

    Args:
        config: BlipVisionConfig
    """

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            BlipEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states


class BlipVisionModel(nn.Module):
    config_class = BlipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                           eps=config.layer_norm_eps)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        return self.post_layernorm(hidden_states)
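
# Minimal usage sketch (illustrative only; the config values are assumptions,
# and the parallel linear layers require an initialized aphrodite distributed
# environment, typically with tensor parallel size 1 for a local test):
#
#     config = BlipVisionConfig(image_size=384, patch_size=16)
#     model = BlipVisionModel(config)
#     pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
#     features = model(pixel_values)  # [1, 577, config.hidden_size]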