blip.py

  1. """Minimal implementation of BlipVisionModel intended to be only used
  2. within a vision language model."""
  3. from array import array
  4. from typing import Optional, Union
  5. import torch
  6. import torch.nn as nn
  7. from PIL import Image
  8. from transformers import Blip2VisionConfig, BlipVisionConfig
  9. from transformers.models.blip.modeling_blip import BlipAttention
  10. from aphrodite.common.config import ModelConfig
  11. from aphrodite.common.sequence import SequenceData
  12. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  13. from aphrodite.inputs import LLMInputs
  14. from aphrodite.modeling.layers.activation import get_act_fn
  15. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  16. RowParallelLinear)
  17. from aphrodite.multimodal.image import (cached_get_tokenizer,
  18. repeat_and_pad_image_tokens)
  19. from aphrodite.quantization import QuantizationConfig
  20. def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  21. assert image_size % patch_size == 0
  22. return image_size // patch_size
  23. def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
  24. grid_length = get_blip_patch_grid_length(image_size=image_size,
  25. patch_size=patch_size)
  26. return grid_length * grid_length
  27. def get_blip_image_feature_size(
  28. hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
  29. return get_blip_num_patches(image_size=hf_config.image_size,
  30. patch_size=hf_config.patch_size)
  31. def get_max_blip_image_tokens(
  32. hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
  33. return get_blip_image_feature_size(hf_config)
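
# Illustrative note (editor's addition, not part of the original file): the
# helpers above reduce to simple patch-grid arithmetic. Assuming a config
# with image_size=224 and patch_size=14 (typical Blip2VisionConfig values):
#
#   get_blip_patch_grid_length(image_size=224, patch_size=14)  # -> 16
#   get_blip_num_patches(image_size=224, patch_size=14)        # -> 16 * 16 = 256
#   get_max_blip_image_tokens(hf_config)                       # -> 256 image tokens
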
def dummy_seq_data_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size)
    return SequenceData(token_ids)
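
# Sketch of the output (editor's addition; image_token_id=50265 is a
# hypothetical placeholder id). With seq_len=2048 and an image feature size
# of 256, the dummy sequence built above is simply
#
#   [50265] * 256 + [0] * (2048 - 256)
#
# i.e. one placeholder image token per patch, padded with token id 0 up to
# the requested sequence length, used only for memory profiling.
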
def dummy_image_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_blip(
    model_config: ModelConfig,
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_blip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
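
# Illustrative sketch (editor's addition): repeat_and_pad_image_tokens expands
# each image placeholder so the prompt reserves one position per vision
# embedding. Assuming image_token_id=50265 (hypothetical) and an image feature
# size of 256, roughly:
#
#   before: [1, 50265, 17, 18]
#   after:  [1] + [50265] * 256 + [17, 18]
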
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):

    def __init__(self, config: BlipVisionConfig):
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = get_blip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        position_embeds = self.position_embedding.to(target_dtype)
        embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]

        return embeddings
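
# Shape walk-through (editor's addition; assumes image_size=224, patch_size=14,
# hidden_size=1408 for illustration):
#   pixel_values [B, 3, 224, 224]
#     -> patch_embedding (strided conv)      [B, 1408, 16, 16]
#     -> flatten(2).transpose(1, 2)          [B, 256, 1408]
#     -> prepend class embedding             [B, 257, 1408]
#     -> add position embeddings             [B, 257, 1408]  (returned)
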
class BlipMLP(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.config = config

        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states
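
# Note (editor's addition): fc1/fc2 use Aphrodite's ColumnParallelLinear and
# RowParallelLinear so the MLP can be sharded across tensor-parallel ranks.
# Both layers return an (output, bias) tuple, hence the `hidden_states, _ = ...`
# unpacking in forward().
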
class BlipEncoderLayer(nn.Module):

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        self.self_attn = BlipAttention(config)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
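
# Note (editor's addition): this is the usual pre-norm transformer block --
# LayerNorm is applied before the attention/MLP sub-layers and the residual is
# added afterwards, mirroring the Hugging Face BlipEncoderLayer it is adapted
# from, with the MLP swapped for the tensor-parallel BlipMLP above.
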
class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention
    layers. Each layer is a [`BlipEncoderLayer`].

    Args:
        config: BlipVisionConfig
    """

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            BlipEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states
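
# Note (editor's addition): num_hidden_layers_override lets a caller build a
# truncated vision tower, e.g. for models that consume features from an
# intermediate encoder layer instead of the final one.
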
class BlipVisionModel(nn.Module):
    config_class = BlipVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: BlipVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()

        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                           eps=config.layer_norm_eps)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        return self.post_layernorm(hidden_states)
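
# Minimal usage sketch (editor's addition, illustrative only; assumes
# Aphrodite's distributed/tensor-parallel state is already initialized, since
# BlipMLP relies on parallel linear layers, and assumes Blip2VisionConfig
# defaults of image_size=224 and patch_size=14):
#
#   from transformers import Blip2VisionConfig
#   cfg = Blip2VisionConfig()
#   vision_model = BlipVisionModel(cfg)
#   pixel_values = torch.randn(1, 3, cfg.image_size, cfg.image_size)
#   feats = vision_model(pixel_values)   # -> [1, 257, cfg.hidden_size]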