clip.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """Minimal implementation of CLIPVisionModel intended to be only used
  2. within a vision language model."""
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. from PIL import Image
  7. from transformers import CLIPVisionConfig
  8. from transformers.models.clip.modeling_clip import CLIPAttention
  9. from aphrodite.common.config import ModelConfig
  10. from aphrodite.common.sequence import SequenceData
  11. from aphrodite.inputs import LLMInputs
  12. from aphrodite.modeling.layers.activation import get_act_fn
  13. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  14. RowParallelLinear)
  15. from aphrodite.multimodal.image import (cached_get_tokenizer,
  16. repeat_and_pad_image_tokens)
  17. from aphrodite.quantization import QuantizationConfig
  18. def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
  19. assert image_size % patch_size == 0
  20. return image_size // patch_size
  21. def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
  22. grid_length = get_clip_patch_grid_length(image_size=image_size,
  23. patch_size=patch_size)
  24. return grid_length * grid_length
  25. def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
  26. return get_clip_num_patches(image_size=hf_config.image_size,
  27. patch_size=hf_config.patch_size)
  28. def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
  29. return get_clip_image_feature_size(hf_config)
  30. def dummy_seq_data_for_clip(
  31. hf_config: CLIPVisionConfig,
  32. seq_len: int,
  33. num_images: int,
  34. *,
  35. image_token_id: int,
  36. image_feature_size_override: Optional[int] = None,
  37. ):
  38. if image_feature_size_override is None:
  39. image_feature_size = get_clip_image_feature_size(hf_config)
  40. else:
  41. image_feature_size = image_feature_size_override
  42. token_ids = [image_token_id] * image_feature_size * num_images
  43. token_ids += [0] * (seq_len - image_feature_size * num_images)
  44. return SequenceData(token_ids)
  45. def dummy_image_for_clip(
  46. hf_config: CLIPVisionConfig,
  47. num_images: int,
  48. *,
  49. image_width_override: Optional[int] = None,
  50. image_height_override: Optional[int] = None,
  51. ):
  52. width = height = hf_config.image_size
  53. if image_width_override is not None:
  54. width = image_width_override
  55. if image_height_override is not None:
  56. height = image_height_override
  57. image = Image.new("RGB", (width, height), color=0)
  58. return {"image": image if num_images == 1 else [image] * num_images}
  59. def input_processor_for_clip(
  60. model_config: ModelConfig,
  61. hf_config: CLIPVisionConfig,
  62. llm_inputs: LLMInputs,
  63. *,
  64. image_token_id: int,
  65. image_feature_size_override: Optional[int] = None,
  66. ):
  67. multi_modal_data = llm_inputs.get("multi_modal_data")
  68. if multi_modal_data is None or "image" not in multi_modal_data:
  69. return llm_inputs
  70. tokenizer = cached_get_tokenizer(model_config.tokenizer)
  71. if image_feature_size_override is None:
  72. image_data = multi_modal_data["image"]
  73. if isinstance(image_data, Image.Image):
  74. image_feature_size = get_clip_image_feature_size(hf_config)
  75. elif isinstance(image_data, torch.Tensor):
  76. image_feature_size = image_data.shape[0]
  77. else:
  78. raise TypeError(f"Invalid image type: {type(image_data)}")
  79. else:
  80. image_feature_size = image_feature_size_override
  81. new_prompt, new_token_ids = repeat_and_pad_image_tokens(
  82. tokenizer,
  83. llm_inputs.get("prompt"),
  84. llm_inputs["prompt_token_ids"],
  85. image_token_id=image_token_id,
  86. repeat_count=image_feature_size,
  87. )
  88. # NOTE: Create a defensive copy of the original inputs
  89. return LLMInputs(prompt_token_ids=new_token_ids,
  90. prompt=new_prompt,
  91. multi_modal_data=multi_modal_data)
  92. # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
  93. class CLIPVisionEmbeddings(nn.Module):
  94. def __init__(self, config: CLIPVisionConfig):
  95. super().__init__()
  96. self.config = config
  97. self.embed_dim = config.hidden_size
  98. self.image_size = config.image_size
  99. self.patch_size = config.patch_size
  100. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  101. self.patch_embedding = nn.Conv2d(
  102. in_channels=config.num_channels,
  103. out_channels=self.embed_dim,
  104. kernel_size=self.patch_size,
  105. stride=self.patch_size,
  106. bias=False,
  107. )
  108. self.num_patches = get_clip_num_patches(image_size=self.image_size,
  109. patch_size=self.patch_size)
  110. self.num_positions = self.num_patches + 1
  111. self.position_embedding = nn.Embedding(self.num_positions,
  112. self.embed_dim)
  113. self.register_buffer("position_ids",
  114. torch.arange(self.num_positions).expand((1, -1)),
  115. persistent=False)
  116. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  117. batch_size = pixel_values.shape[0]
  118. target_dtype = self.patch_embedding.weight.dtype
  119. patch_embeds = self.patch_embedding(pixel_values.to(
  120. dtype=target_dtype)) # shape = [*, width, grid, grid]
  121. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  122. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  123. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  124. embeddings = embeddings + self.position_embedding(self.position_ids)
  125. return embeddings
  126. class CLIPMLP(nn.Module):
  127. def __init__(self,
  128. config: CLIPVisionConfig,
  129. quant_config: Optional[QuantizationConfig] = None):
  130. super().__init__()
  131. self.config = config
  132. self.activation_fn = get_act_fn(config.hidden_act)
  133. self.fc1 = ColumnParallelLinear(config.hidden_size,
  134. config.intermediate_size,
  135. bias=True,
  136. quant_config=quant_config)
  137. self.fc2 = RowParallelLinear(config.intermediate_size,
  138. config.hidden_size,
  139. bias=True,
  140. quant_config=quant_config)
  141. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  142. hidden_states, _ = self.fc1(hidden_states)
  143. hidden_states = self.activation_fn(hidden_states)
  144. hidden_states, _ = self.fc2(hidden_states)
  145. return hidden_states
  146. class CLIPEncoderLayer(nn.Module):
  147. def __init__(self,
  148. config: CLIPVisionConfig,
  149. quant_config: Optional[QuantizationConfig] = None):
  150. super().__init__()
  151. self.self_attn = CLIPAttention(config)
  152. self.layer_norm1 = nn.LayerNorm(config.hidden_size,
  153. eps=config.layer_norm_eps)
  154. self.mlp = CLIPMLP(config, quant_config=quant_config)
  155. self.layer_norm2 = nn.LayerNorm(config.hidden_size,
  156. eps=config.layer_norm_eps)
  157. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  158. residual = hidden_states
  159. hidden_states = self.layer_norm1(hidden_states)
  160. hidden_states, _ = self.self_attn(hidden_states=hidden_states)
  161. hidden_states = residual + hidden_states
  162. residual = hidden_states
  163. hidden_states = self.layer_norm2(hidden_states)
  164. hidden_states = self.mlp(hidden_states)
  165. hidden_states = residual + hidden_states
  166. return hidden_states
  167. class CLIPEncoder(nn.Module):
  168. """
  169. Transformer encoder consisting of `config.num_hidden_layers` self
  170. attention layers. Each layer is a [`CLIPEncoderLayer`].
  171. Args:
  172. config: CLIPConfig
  173. """
  174. def __init__(self,
  175. config: CLIPVisionConfig,
  176. quant_config: Optional[QuantizationConfig] = None,
  177. num_hidden_layers_override: Optional[int] = None):
  178. super().__init__()
  179. self.config = config
  180. if num_hidden_layers_override is None:
  181. num_hidden_layers = config.num_hidden_layers
  182. else:
  183. num_hidden_layers = num_hidden_layers_override
  184. self.layers = nn.ModuleList([
  185. CLIPEncoderLayer(config=config, quant_config=quant_config)
  186. for _ in range(num_hidden_layers)
  187. ])
  188. def forward(self, inputs_embeds: torch.Tensor):
  189. hidden_states = inputs_embeds
  190. for encoder_layer in self.layers:
  191. hidden_states = encoder_layer(hidden_states)
  192. return hidden_states
  193. class CLIPVisionTransformer(nn.Module):
  194. def __init__(self,
  195. config: CLIPVisionConfig,
  196. quant_config: Optional[QuantizationConfig] = None,
  197. num_hidden_layers_override: Optional[int] = None):
  198. super().__init__()
  199. self.config = config
  200. embed_dim = config.hidden_size
  201. self.embeddings = CLIPVisionEmbeddings(config)
  202. # NOTE: This typo of "layrnorm" is not fixed on purpose to match
  203. # the original transformers code and name of the model weights.
  204. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  205. self.encoder = CLIPEncoder(
  206. config=config,
  207. quant_config=quant_config,
  208. num_hidden_layers_override=num_hidden_layers_override)
  209. def forward(
  210. self,
  211. pixel_values: torch.Tensor,
  212. ) -> torch.Tensor:
  213. hidden_states = self.embeddings(pixel_values)
  214. hidden_states = self.pre_layrnorm(hidden_states)
  215. hidden_states = self.encoder(inputs_embeds=hidden_states)
  216. return hidden_states
  217. class CLIPVisionModel(nn.Module):
  218. config_class = CLIPVisionConfig
  219. main_input_name = "pixel_values"
  220. def __init__(self,
  221. config: CLIPVisionConfig,
  222. quant_config: Optional[QuantizationConfig] = None,
  223. num_hidden_layers_override: Optional[int] = None):
  224. super().__init__()
  225. self.vision_model = CLIPVisionTransformer(
  226. config=config,
  227. quant_config=quant_config,
  228. num_hidden_layers_override=num_hidden_layers_override)
  229. def forward(self, pixel_values: Optional[torch.Tensor] = None):
  230. return self.vision_model(pixel_values=pixel_values)
  231. @property
  232. def device(self):
  233. return next(self.parameters()).device