blip.py

  1. """Minimal implementation of BlipVisionModel intended to be only used
  2. within a vision language model."""
  3. from typing import Optional, Union
  4. import torch
  5. import torch.nn as nn
  6. from PIL import Image
  7. from transformers import Blip2VisionConfig, BlipVisionConfig
  8. from transformers.models.blip.modeling_blip import BlipAttention
  9. from aphrodite.common.config import ModelConfig
  10. from aphrodite.common.sequence import SequenceData
  11. from aphrodite.inputs import LLMInputs
  12. from aphrodite.modeling.layers.activation import get_act_fn
  13. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  14. RowParallelLinear)
  15. from aphrodite.multimodal.image import (cached_get_tokenizer,
  16. repeat_and_pad_image_tokens)
  17. from aphrodite.quantization import QuantizationConfig


def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    assert image_size % patch_size == 0
    return image_size // patch_size


def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_blip_patch_grid_length(image_size=image_size,
                                             patch_size=patch_size)
    return grid_length * grid_length


def get_blip_image_feature_size(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
    return get_blip_num_patches(image_size=hf_config.image_size,
                                patch_size=hf_config.patch_size)


def get_max_blip_image_tokens(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
    return get_blip_image_feature_size(hf_config)
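
# Worked example (illustrative only; the numbers assume the stock
# transformers defaults of image_size=384 and patch_size=16 for
# BlipVisionConfig):
#   get_blip_patch_grid_length(image_size=384, patch_size=16) -> 24
#   get_blip_num_patches(image_size=384, patch_size=16)       -> 24 * 24 = 576
# so a single image contributes 576 feature tokens to the language model.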


def dummy_seq_data_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)
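
# For example (illustrative numbers): with the 576-token feature size above
# and seq_len=2048, the dummy sequence is 576 copies of image_token_id
# followed by 1472 padding zeros. Dummy data like this is typically used
# when profiling memory for image-bearing prompts.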


def dummy_image_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}


def input_processor_for_blip(
    model_config: ModelConfig,
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
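
# Rough sketch of the effect (assuming repeat_and_pad_image_tokens behaves as
# its name suggests): a prompt tokenized as
#   [bos, "Describe", <image>, "please"]
# comes back with the single <image> placeholder repeated image_feature_size
# times, so the language model sees one token position per vision feature.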


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):

    def __init__(self, config: BlipVisionConfig):
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = get_blip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        position_embeds = self.position_embedding.to(target_dtype)
        embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]

        return embeddings
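
# Shape walk-through (illustrative, assuming the default BlipVisionConfig of
# hidden_size=768, image_size=384, patch_size=16):
#   pixel_values   [B, 3, 384, 384]
#   patch_embeds   [B, 768, 24, 24] -> flatten(2).transpose(1, 2) -> [B, 576, 768]
#   + class token  -> [B, 577, 768], to which the position embeddings are added.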


class BlipMLP(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.config = config

        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states
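
# Design note: fc1 is column-parallel and fc2 is row-parallel, the usual
# Megatron-style pairing, so the intermediate activation stays sharded across
# tensor-parallel ranks and only fc2 needs to reduce its output.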


class BlipEncoderLayer(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.self_attn = BlipAttention(config)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
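
# This is a standard pre-norm transformer block:
#   x = x + SelfAttn(LayerNorm1(x))
#   x = x + MLP(LayerNorm2(x))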


class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`BlipEncoderLayer`].

    Args:
        config: BlipConfig
    """

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            BlipEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states


class BlipVisionModel(nn.Module):
    config_class = BlipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                           eps=config.layer_norm_eps)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        return self.post_layernorm(hidden_states)
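
# Usage sketch (illustrative only; aphrodite's ColumnParallelLinear and
# RowParallelLinear expect the distributed/tensor-parallel state to be
# initialized before the model is built):
#
#   config = BlipVisionConfig()  # defaults: image_size=384, patch_size=16
#   vision_model = BlipVisionModel(config)
#   pixel_values = torch.randn(2, 3, config.image_size, config.image_size)
#   features = vision_model(pixel_values)  # [2, 577, config.hidden_size]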