  1. """Minimal implementation of BlipVisionModel intended to be only used
  2. within a vision language model."""
  3. from typing import Iterable, Optional, Tuple, Union
  4. import torch
  5. import torch.nn as nn
  6. from PIL import Image
  7. from transformers import Blip2VisionConfig, BlipVisionConfig
  8. from transformers.models.blip.modeling_blip import BlipAttention
  9. from aphrodite.common.config import ModelConfig
  10. from aphrodite.common.sequence import SequenceData
  11. from aphrodite.distributed import divide, get_tensor_model_parallel_world_size
  12. from aphrodite.inputs import LLMInputs
  13. from aphrodite.modeling.layers.activation import get_act_fn
  14. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  15. QKVParallelLinear,
  16. RowParallelLinear)
  17. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  18. from aphrodite.multimodal.utils import (cached_get_tokenizer,
  19. repeat_and_pad_placeholder_tokens)
  20. from aphrodite.quantization import QuantizationConfig
  21. try:
  22. from xformers import ops as xops
  23. USE_XFORMERS_OPS = True
  24. except ImportError:
  25. USE_XFORMERS_OPS = False


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    assert image_size % patch_size == 0
    return image_size // patch_size


def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_blip_patch_grid_length(image_size=image_size,
                                             patch_size=patch_size)
    return grid_length * grid_length


def get_blip_image_feature_size(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
    return get_blip_num_patches(image_size=hf_config.image_size,
                                patch_size=hf_config.patch_size)


def get_max_blip_image_tokens(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
    return get_blip_image_feature_size(hf_config)
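
# Worked example (illustration only, using the transformers defaults of
# image_size=384 and patch_size=16 for BlipVisionConfig): the patch grid is
# 384 // 16 = 24 per side, so get_blip_image_feature_size returns
# 24 * 24 = 576 image feature tokens.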


def dummy_seq_data_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    return SequenceData.from_token_counts(
        (image_token_id, image_feature_size * num_images),
        (0, seq_len - image_feature_size * num_images),
    )
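
# Illustrative sketch (hypothetical numbers): with seq_len=1024, num_images=1
# and a 576-token image feature size, the dummy sequence built above consists
# of 576 copies of image_token_id followed by 1024 - 576 = 448 filler tokens
# (token id 0).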


def dummy_image_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_blip(
    model_config: ModelConfig,
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
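
# Rough sketch of the expansion performed above (hypothetical token ids):
# a prompt of [bos, image_token_id, q1, q2] with repeat_count=3 becomes
# [bos, image_token_id, image_token_id, image_token_id, q1, q2], i.e. each
# image placeholder is widened to one slot per image feature token.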


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):

    def __init__(self, config: BlipVisionConfig):
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = get_blip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        position_embeds = self.position_embedding.to(target_dtype)
        embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]

        return embeddings
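
# Shape flow through BlipVisionEmbeddings.forward (as written above):
#   pixel_values:  [batch, 3, image_size, image_size]
#   patch_embeds:  [batch, embed_dim, grid, grid] -> [batch, num_patches, embed_dim]
#   embeddings:    [batch, num_patches + 1, embed_dim] after prepending the
#                  class token and adding the learned position embeddings.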


class BlipParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: BlipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )
        self.projection = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        bsz, tgt_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        query_states = query_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(bsz, tgt_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(bsz, tgt_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)

        out = xops.memory_efficient_attention_forward(query_states,
                                                      key_states,
                                                      value_states,
                                                      p=self.dropout,
                                                      scale=self.scale)
        out = out.view(bsz, tgt_len, -1)
        attn_output, _ = self.projection(out)

        return attn_output, None
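
# Note on the tensor-parallel layout above: QKVParallelLinear shards the
# attention heads across ranks, so each rank handles num_heads / tp_size heads
# and the chunk() call splits its local output into per-rank Q, K and V of
# width num_heads_per_partition * head_dim. The RowParallelLinear projection
# then combines the per-rank partial results (typically via an all-reduce)
# back into the full embed_dim output.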


class BlipMLP(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.config = config

        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states
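
# Design note: fc1 is column-parallel and fc2 is row-parallel, the usual
# Megatron-style pairing that keeps the intermediate activation sharded across
# ranks and needs only a single result reduction at the end of the MLP.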


class BlipEncoderLayer(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # Fall back to HF's eager BlipAttention when xformers is missing or
        # the heads cannot be split evenly across the TP ranks.
        num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.self_attn = BlipParallelAttention(config,
                                                   quant_config=quant_config)
        else:
            # BLIP has no SDPA attention implementation in transformers,
            # so use eager attention instead (e.g. for the CPU backend).
            self.self_attn = BlipAttention(config)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
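
# The residual layout above is pre-norm: LayerNorm is applied before the
# attention and MLP blocks and the residual is added afterwards, mirroring
# the HF BlipEncoderLayer implementation.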


class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers`
    self-attention layers. Each layer is a [`BlipEncoderLayer`].

    Args:
        config: BlipVisionConfig
    """

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            BlipEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states


class BlipVisionModel(nn.Module):
    config_class = BlipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        tp_size = get_tensor_model_parallel_world_size()
        num_heads = config.num_attention_heads
        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {config.num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )
        elif len(self.encoder.layers) == config.num_hidden_layers:
            self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                               eps=config.layer_norm_eps)
        else:
            # post_layernorm is unused when we extract intermediate features
            # In this case, we can skip it to conserve memory
            self.post_layernorm = None

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        if self.post_layernorm is None:
            return hidden_states

        return self.post_layernorm(hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ] if self.shard_weight else []
        params_dict = dict(self.named_parameters())
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in BlipVisionModel
            if (name.startswith("post_layernorm")
                    and self.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("encoder.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue

                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
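
# Minimal usage sketch (not executed here; assumes the Aphrodite distributed
# environment has already been initialized so the parallel linear layers can
# query the TP world size):
#
#   config = BlipVisionConfig()
#   vision_model = BlipVisionModel(config)
#   pixel_values = torch.randn(1, 3, config.image_size, config.image_size)
#   features = vision_model(pixel_values)  # [1, num_patches + 1, hidden_size]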