clip.py

  1. """Minimal implementation of CLIPVisionModel intended to be only used
  2. within a vision language model."""
  3. from array import array
  4. from typing import Optional
  5. import torch
  6. import torch.nn as nn
  7. from PIL import Image
  8. from transformers import CLIPVisionConfig
  9. from transformers.models.clip.modeling_clip import CLIPAttention
  10. from aphrodite.common.config import ModelConfig
  11. from aphrodite.common.sequence import SequenceData
  12. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  13. from aphrodite.inputs import LLMInputs
  14. from aphrodite.modeling.layers.activation import get_act_fn
  15. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  16. RowParallelLinear)
  17. from aphrodite.multimodal.image import (cached_get_tokenizer,
  18. repeat_and_pad_image_tokens)
  19. from aphrodite.quantization import QuantizationConfig
  20. def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  21. assert image_size % patch_size == 0
  22. return image_size // patch_size
  23. def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
  24. grid_length = get_clip_patch_grid_length(image_size=image_size,
  25. patch_size=patch_size)
  26. return grid_length * grid_length
  27. def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
  28. return get_clip_num_patches(image_size=hf_config.image_size,
  29. patch_size=hf_config.patch_size)
  30. def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
  31. return get_clip_image_feature_size(hf_config)
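
# Illustrative arithmetic: with the transformers default CLIPVisionConfig
# (image_size=224, patch_size=32), the helpers above yield a 224 // 32 = 7
# patch grid, i.e. 7 * 7 = 49 patches, so each image contributes 49 feature
# tokens.
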
def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_clip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size * num_images
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)
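
# Layout of the dummy sequence built above: the first
# image_feature_size * num_images positions hold image_token_id, and the
# remaining positions up to seq_len are padded with token id 0.
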
def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_clip(
    model_config: ModelConfig,
    hf_config: CLIPVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
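
# Illustrative effect of input_processor_for_clip on a prompt (the exact
# padding behavior lives in repeat_and_pad_image_tokens): each image
# placeholder token in prompt_token_ids is expanded to image_feature_size
# copies, e.g. [bos, <image>, q1, q2] -> [bos, <image> * 49, q1, q2] for a
# 49-token image, so the language model reserves one position per image
# feature.
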
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = get_clip_num_patches(image_size=self.image_size,
                                                patch_size=self.patch_size)
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings
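
# Shape walkthrough for CLIPVisionEmbeddings.forward, assuming the default
# 224x224 input with 32x32 patches and hidden_size=768:
#   pixel_values: [B, 3, 224, 224]
#   patch_embeds: [B, 768, 7, 7] after the conv, then [B, 49, 768] after
#                 flatten(2).transpose(1, 2)
#   embeddings:   [B, 50, 768] once the class embedding is prepended and the
#                 position embeddings are added
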
class CLIPMLP(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config)
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class CLIPEncoderLayer(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config, quant_config=quant_config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
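
# Note: CLIPEncoderLayer uses the pre-norm residual pattern from the upstream
# transformers implementation: LayerNorm -> attention -> residual add, then
# LayerNorm -> MLP -> residual add, keeping hidden_states at
# [B, num_positions, hidden_size] throughout.
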
class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override
        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
        return hidden_states
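
# Note: num_hidden_layers_override builds a vision tower with fewer encoder
# layers than config.num_hidden_layers; a host VLM that consumes features from
# an intermediate layer can presumably use this to skip the unused trailing
# layers.
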
class CLIPVisionTransformer(nn.Module):

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)
        hidden_states = self.encoder(inputs_embeds=hidden_states)

        return hidden_states


class CLIPVisionModel(nn.Module):

    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self,
                 config: CLIPVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
        super().__init__()
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def forward(self, pixel_values: Optional[torch.Tensor] = None):
        return self.vision_model(pixel_values=pixel_values)

    @property
    def device(self):
        return next(self.parameters()).device
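
# Usage sketch (shape-level only; the parallel linear layers assume the engine
# has already initialized aphrodite's tensor-parallel state, so this is not a
# standalone script):
#
#   config = CLIPVisionConfig()            # defaults: 224x224 images, 32x32 patches
#   model = CLIPVisionModel(config)
#   pixel_values = torch.randn(1, 3, 224, 224)
#   features = model(pixel_values)         # -> [1, 50, 768] hidden states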