blip.py

  1. """Minimal implementation of BlipVisionModel intended to be only used
  2. within a vision language model."""
  3. from typing import Optional, Union
  4. import torch
  5. import torch.nn as nn
  6. from PIL import Image
  7. from transformers import Blip2VisionConfig, BlipVisionConfig
  8. from transformers.models.blip.modeling_blip import BlipAttention
  9. from aphrodite.common.config import ModelConfig
  10. from aphrodite.common.sequence import SequenceData
  11. from aphrodite.inputs import LLMInputs
  12. from aphrodite.modeling.layers.activation import get_act_fn
  13. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  14. RowParallelLinear)
  15. from aphrodite.multimodal.image import (cached_get_tokenizer,
  16. repeat_and_pad_image_tokens)
  17. from aphrodite.quantization import QuantizationConfig
  18. def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  19. assert image_size % patch_size == 0
  20. return image_size // patch_size
  21. def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
  22. grid_length = get_blip_patch_grid_length(image_size=image_size,
  23. patch_size=patch_size)
  24. return grid_length * grid_length
  25. def get_blip_image_feature_size(
  26. hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
  27. return get_blip_num_patches(image_size=hf_config.image_size,
  28. patch_size=hf_config.patch_size)
  29. def get_max_blip_image_tokens(
  30. hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
  31. return get_blip_image_feature_size(hf_config)
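# For instance, with the default BlipVisionConfig (image_size=384,
# patch_size=16) the image is split into a 24 x 24 patch grid, so each
# image contributes (384 // 16) ** 2 = 576 feature tokens.
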
def dummy_seq_data_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)

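# The dummy sequence above fills the first `image_feature_size` positions
# with the image placeholder token and pads the rest with token id 0; it
# is typically consumed when profiling the worst-case memory usage of the
# wrapping vision language model.
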
def dummy_image_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}

def input_processor_for_blip(
    model_config: ModelConfig,
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)

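# The processor above expands each image placeholder in the prompt into
# `image_feature_size` copies so that the token sequence reserves one
# position per visual feature; a wrapping vision language model would
# typically register it (together with the dummy-data helpers) as its
# input-processing hook.
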
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):

    def __init__(self, config: BlipVisionConfig):
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = get_blip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        position_embeds = self.position_embedding.to(target_dtype)
        embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]

        return embeddings

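# Shape flow through the embeddings, assuming the default BlipVisionConfig
# (image_size=384, patch_size=16, hidden_size=768):
#   pixel_values:                    [B, 3, 384, 384]
#   patch_embedding (strided conv):  [B, 768, 24, 24]
#   flatten + transpose:             [B, 576, 768]
#   prepend class embedding:         [B, 577, 768]
# followed by additive learned position embeddings of matching length.
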
class BlipMLP(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.config = config

        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states

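# The ColumnParallelLinear/RowParallelLinear pairing is the usual
# tensor-parallel MLP layout: fc1 shards the intermediate dimension across
# ranks, the activation is applied locally, and fc2 reduces the partial
# results back to the full hidden size, so no extra communication is
# needed between the two projections.
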
class BlipEncoderLayer(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.self_attn = BlipAttention(config)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

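# Each layer follows the pre-LayerNorm transformer pattern: normalize,
# apply the sublayer (self-attention or MLP), then add the residual,
# mirroring the reference Hugging Face BlipEncoderLayer.
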
class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`BlipEncoderLayer`].

    Args:
        config: BlipVisionConfig
    """

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            BlipEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states

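# `num_hidden_layers_override` lets a wrapping model instantiate a
# truncated vision tower, e.g. when it only consumes features from an
# intermediate layer and the remaining layers would be dead weight.
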
class BlipVisionModel(nn.Module):
    config_class = BlipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                           eps=config.layer_norm_eps)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        return self.post_layernorm(hidden_states)

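# Minimal usage sketch (not part of the module); shapes assume the default
# BlipVisionConfig (image_size=384, patch_size=16, hidden_size=768):
#
#   config = BlipVisionConfig()
#   vision_model = BlipVisionModel(config)
#   pixel_values = torch.randn(2, 3, 384, 384)
#   features = vision_model(pixel_values)   # [2, 577, 768]
#
# A wrapping vision language model would then project `features` into its
# language embedding space and scatter them over the placeholder tokens
# produced by `input_processor_for_blip`.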